# Source code for climagrid.forecasting.dataset

"""
Dataset construction for stress-feature forecasting.

Two stages:

1. ``build_training_panel`` turns an asset registry into a daily panel by
   fetching hourly climagrid features per asset (streaming one asset at a time
   so the full hourly panel is never held in memory) and aggregating to one
   row per (asset, day). The daily panel is optionally cached to parquet.

2. ``build_supervised_frame`` turns the daily panel into a supervised learning
   frame: backward-looking predictors (autoregressive lags, trailing rolling
   statistics, calendar harmonics, static location) plus the shifted target
   columns ``y_h1..y_hH``.

All predictors are strictly causal (computed from data available at or before
the forecast origin), and rows are sorted by (asset_id, date) before any lag
or rolling operation because the underlying climagrid feature functions assume
time-sorted input and do not sort internally.
"""

from __future__ import annotations

import hashlib
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

from climagrid.assets.registry import AssetRegistry
from climagrid.forecasting.config import ForecastConfig
from climagrid.pipeline.orchestrator import run as default_run

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type of the fetch function injected into ``build_training_panel`` via
# ``run_fn``. It is called there as
# ``run_fn(registry, start, end, sources=..., features=...)`` and must return
# a DataFrame.
RunFn = Callable[..., pd.DataFrame]

# Long fetches are split into chunks so a single huge request (e.g. 25 years of
# hourly NASA POWER data) cannot time out or be rejected. Each chunk is fetched
# with a lookback buffer so trailing-window features (up to 720 h) stay
# continuous across chunk boundaries; the buffer days are dropped after
# aggregation. NASA POWER hourly data begins in 2001, the floor for the buffer.
_FETCH_CHUNK_YEARS = 5  # maximum span of one fetch chunk, in years
_FETCH_BUFFER_DAYS = 35  # lookback fetched before each chunk (35 d = 840 h >= 720 h)
_HISTORY_FLOOR_YEAR = 2001  # the buffered fetch start is clamped to Jan 1 of this year


def predictor_columns(config: ForecastConfig) -> list[str]:
    """List the predictor column names for ``config`` in canonical order.

    The order is: the origin value ``y_t``, one ``lag_<k>`` per configured
    lag, a ``rollmean_<w>`` / ``rollstd_<w>`` pair per rolling window, then
    the day-of-year harmonics and the static location columns.
    """
    lag_names = [f"lag_{lag}" for lag in config.lags]
    # Windows vary slowest so each mean/std pair stays adjacent.
    rolling_names = [
        f"{stat}_{window}"
        for window in config.rolling_windows
        for stat in ("rollmean", "rollstd")
    ]
    return ["y_t", *lag_names, *rolling_names, "doy_sin", "doy_cos", "lat", "lon"]


def horizon_target_columns(config: ForecastConfig) -> list[str]:
    """Name the shifted-target columns, ``y_h1`` up to ``y_h<horizon_days>``."""
    horizon_steps = range(1, config.horizon_days + 1)
    return list(map("y_h{}".format, horizon_steps))


def _iter_single_asset_registries(registry: AssetRegistry) -> Iterator[AssetRegistry]:
    """Yield a one-asset AssetRegistry for each asset, via a temp CSV.

    Streaming one asset at a time bounds peak memory to a single asset's hourly
    series rather than the whole panel.

    Each asset's ``asset_id``, ``lat`` and ``lon`` are written to a temporary
    CSV, which is used to construct the yielded registry and is removed once
    the consumer moves on to the next asset (or closes the generator).
    """
    columns = ["asset_id", "lat", "lon"]
    assets = registry.assets
    for position in range(len(assets)):
        # Slice the row instead of using iterrows(): iterrows() packs each row
        # into a single Series and therefore does not preserve dtypes, so an
        # integer asset_id next to float lat/lon would be written as "1.0".
        single_asset = assets.iloc[[position]][columns]
        handle = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False)
        path = handle.name
        try:
            # The file must be closed before AssetRegistry re-opens it by name.
            with handle:
                single_asset.to_csv(handle, index=False)
            yield AssetRegistry(path)
        finally:
            # Covers write failures too, so the temp file is never leaked.
            os.unlink(path)


def _aggregate_daily(raw: pd.DataFrame, config: ForecastConfig) -> pd.DataFrame:
    """Collapse an hourly climagrid frame into one row per (asset_id, date).

    ``date`` is the UTC calendar day of each ``timestamp``, stored as a naive
    midnight timestamp. Every configured target that is present in the frame
    is reduced with ``config.daily_agg``; ``lat`` and ``lon`` keep their first
    value per group. The input frame is not modified.
    """
    utc_stamps = pd.to_datetime(raw["timestamp"], utc=True)
    naive_utc = utc_stamps.dt.tz_convert("UTC").dt.tz_localize(None)
    hourly = raw.assign(date=naive_utc.dt.normalize())

    available_targets = [name for name in config.targets if name in hourly.columns]
    if len(available_targets) == 0:
        logger.warning(
            "None of the target columns %s are present in fetched data.",
            config.targets,
        )

    # Targets first, then the static location columns.
    reducers: dict[str, str] = dict.fromkeys(available_targets, config.daily_agg)
    reducers.update(lat="first", lon="first")

    per_day = hourly.groupby(["asset_id", "date"], as_index=False)
    return per_day.agg(reducers)  # type: ignore[no-any-return]


def _cache_path(
    registry: AssetRegistry,
    history_start: datetime,
    history_end: datetime,
    config: ForecastConfig,
) -> Path | None:
    """Return the deterministic cache file for a panel, or None if caching is off.

    The filename embeds the first 16 hex digits of a SHA-256 over the asset
    count, the sorted asset ids, the history range, the sorted targets and
    sources, and the daily aggregation.
    """
    cache_dir = config.cache_dir
    if cache_dir is None:
        return None

    asset_ids = sorted(map(str, registry.assets["asset_id"]))
    fields = (
        str(registry.count),
        ",".join(asset_ids),
        history_start.isoformat(),
        history_end.isoformat(),
        ",".join(sorted(config.targets)),
        ",".join(sorted(config.sources)),
        config.daily_agg,
    )
    fingerprint = hashlib.sha256("|".join(fields).encode("utf-8")).hexdigest()
    return Path(cache_dir) / f"panel_{fingerprint[:16]}.parquet"


def _date_chunks(
    start: datetime, end: datetime, chunk_years: int
) -> list[tuple[datetime, datetime]]:
    """Split ``[start, end]`` into consecutive chunks of at most ``chunk_years``.

    Each chunk is an inclusive ``(chunk_start, chunk_end)`` pair; the next
    chunk starts one day after the previous one ends, and the last chunk is
    clipped to ``end``. Returns an empty list when ``start`` is after ``end``.

    Raises
    ------
    ValueError
        If ``chunk_years`` is less than 1. A non-positive span would never
        advance ``chunk_start`` past ``end`` and the loop would not terminate.
    """
    if chunk_years < 1:
        raise ValueError(f"chunk_years must be >= 1, got {chunk_years}")

    one_day = timedelta(days=1)
    chunks: list[tuple[datetime, datetime]] = []
    chunk_start = start
    while chunk_start <= end:
        target_year = chunk_start.year + chunk_years
        try:
            boundary = chunk_start.replace(year=target_year)
        except ValueError:  # Feb 29 -> fall back to Feb 28 in the target year
            boundary = chunk_start.replace(year=target_year, day=28)
        # ``boundary`` is the first day of the *next* chunk, hence the -1 day.
        chunk_end = min(boundary - one_day, end)
        chunks.append((chunk_start, chunk_end))
        chunk_start = chunk_end + one_day
    return chunks


def _write_panel_cache(panel: pd.DataFrame, cache_path: Path) -> None:
    """Write ``panel`` to ``cache_path`` via a temp file and an atomic rename."""
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    # Write next to the final file so the rename stays on one filesystem; the
    # PID keeps concurrent writers from sharing a temp file.
    tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
    try:
        panel.to_parquet(tmp_path, index=False)
        os.replace(tmp_path, cache_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)
    logger.info("Cached daily panel to %s", cache_path)


def build_training_panel(
    assets: AssetRegistry | str | Path,
    history_start: datetime,
    history_end: datetime,
    config: ForecastConfig,
    *,
    run_fn: RunFn | None = None,
) -> pd.DataFrame:
    """
    Build the daily training panel for a set of assets.

    Parameters
    ----------
    assets:
        An AssetRegistry or a path to an asset CSV/GeoJSON.
    history_start, history_end:
        Inclusive UTC date range of history to fetch.
    config:
        Forecast configuration (targets, sources, daily aggregation, cache).
    run_fn:
        Injection point for ``climagrid.run`` (used by tests to mock fetching).

    Returns
    -------
    pd.DataFrame
        Columns: ``asset_id``, ``date``, ``lat``, ``lon`` and one column per
        present target, with one row per (asset, day), sorted by
        (asset_id, date). Empty if no data was returned for any asset.

    Notes
    -----
    A chunk whose fetch raises is logged and skipped, so the returned panel
    may have gaps. Such a panel is still returned but is *not* written to the
    cache; otherwise a transient failure would be replayed from the cache on
    every later call.
    """
    resolved_run = run_fn or default_run
    registry = assets if isinstance(assets, AssetRegistry) else AssetRegistry(assets)

    cache_path = _cache_path(registry, history_start, history_end, config)
    if cache_path is not None and cache_path.exists():
        logger.info("Loading cached daily panel from %s", cache_path)
        return pd.read_parquet(cache_path)  # type: ignore[no-any-return]

    features = config.required_features()
    floor = datetime(_HISTORY_FLOOR_YEAR, 1, 1, tzinfo=history_start.tzinfo)
    buffer = timedelta(days=_FETCH_BUFFER_DAYS)
    chunks = _date_chunks(history_start, history_end, _FETCH_CHUNK_YEARS)

    daily_frames: list[pd.DataFrame] = []
    failed_fetches = 0
    for sub in _iter_single_asset_registries(registry):
        asset_chunks: list[pd.DataFrame] = []
        for chunk_start, chunk_end in chunks:
            # Fetch a little before the chunk so trailing-window features are
            # warm at the boundary; the buffer days are dropped below.
            fetch_start = max(chunk_start - buffer, floor)
            try:
                raw = resolved_run(
                    sub,
                    fetch_start,
                    chunk_end,
                    sources=config.sources,
                    features=features,
                )
            except Exception as exc:
                # Best effort: one bad chunk must not abort the whole panel,
                # but remember it so the gap is not cached.
                failed_fetches += 1
                logger.warning(
                    "Fetch failed for chunk %s..%s: %s",
                    fetch_start.date(),
                    chunk_end.date(),
                    exc,
                )
                continue
            if raw is None or raw.empty:
                continue
            daily = _aggregate_daily(raw, config)
            # Keep only the days that belong to this chunk (drops the buffer).
            lo = pd.Timestamp(chunk_start.date())
            hi = pd.Timestamp(chunk_end.date())
            daily = daily[(daily["date"] >= lo) & (daily["date"] <= hi)]
            if not daily.empty:
                asset_chunks.append(daily)
        if asset_chunks:
            daily_frames.append(pd.concat(asset_chunks, ignore_index=True))
        else:
            logger.warning("No data returned for an asset; skipping it in the panel.")

    if not daily_frames:
        logger.warning("Training panel is empty: no asset returned any data.")
        return pd.DataFrame()

    panel = (
        pd.concat(daily_frames, ignore_index=True)
        .sort_values(["asset_id", "date"])
        .reset_index(drop=True)
    )

    if cache_path is not None:
        if failed_fetches:
            logger.warning(
                "Not caching daily panel: %d chunk fetch(es) failed, so the "
                "panel may be incomplete.",
                failed_fetches,
            )
        else:
            _write_panel_cache(panel, cache_path)
    return panel  # type: ignore[no-any-return]


def build_supervised_frame(
    panel: pd.DataFrame,
    target: str,
    config: ForecastConfig,
) -> pd.DataFrame:
    """
    Turn a daily panel into a supervised frame for one target.

    Predictors (all known at the forecast origin ``t``):

    - ``y_t``: the target value at the origin,
    - ``lag_k``: the target value ``k`` rows before the origin for each
      configured lag,
    - ``rollmean_w`` / ``rollstd_w``: trailing rolling mean/std over the ``w``
      rows ending just before the origin (shifted by 1 so the origin is
      excluded),
    - ``doy_sin`` / ``doy_cos``: day-of-year harmonics of the origin date,
    - ``lat`` / ``lon``: static location.

    Targets: ``y_h{h}`` is the target value ``h`` rows after the origin for
    each horizon.

    All shifts are positional within each asset after sorting by date, so
    rows correspond to calendar days only if every asset has one row per
    consecutive day.

    Returns one row per (asset, origin date). Early rows have NaN lag/rolling
    predictors and late rows have NaN ``y_h*`` targets; both are retained here
    (LightGBM tolerates NaN predictors; NaN targets are dropped at fit time).

    Raises
    ------
    KeyError
        If ``target`` is not a column of ``panel``.
    """
    if target not in panel.columns:
        raise KeyError(f"Target {target!r} not in panel columns: {list(panel.columns)}")
    if panel.empty:
        return panel.copy()  # type: ignore[no-any-return]

    def _trailing_mean(series: pd.Series, span: int) -> pd.Series:
        return series.shift(1).rolling(span, min_periods=1).mean()

    def _trailing_std(series: pd.Series, span: int) -> pd.Series:
        return series.shift(1).rolling(span, min_periods=2).std()

    frame = panel.sort_values(["asset_id", "date"]).reset_index(drop=True).copy()
    per_asset = frame.groupby("asset_id", sort=False)[target]

    frame["y_t"] = frame[target].astype(float)
    for lag in config.lags:
        frame[f"lag_{lag}"] = per_asset.shift(lag)

    for window in config.rolling_windows:
        frame[f"rollmean_{window}"] = per_asset.transform(_trailing_mean, span=window)
        frame[f"rollstd_{window}"] = per_asset.transform(_trailing_std, span=window)

    # One annual cycle maps onto a full turn of the unit circle.
    phase = 2.0 * np.pi * frame["date"].dt.dayofyear / 365.25
    frame["doy_sin"] = np.sin(phase)
    frame["doy_cos"] = np.cos(phase)

    for step in range(1, config.horizon_days + 1):
        frame[f"y_h{step}"] = per_asset.shift(-step)

    # dict.fromkeys de-duplicates while keeping first-seen order (the target
    # may coincide with a predictor name).
    wanted = [
        "asset_id",
        "date",
        target,
        *predictor_columns(config),
        *horizon_target_columns(config),
    ]
    ordered = list(dict.fromkeys(wanted))
    return frame[ordered]  # type: ignore[no-any-return]