Source code for ocean_taco.dataset.dataset

"""OceanTACO Dataset - PyTorch Dataset for ocean surface state data.

Design principles:
- Query-based: samples defined by (bbox, time_range) queries
- Trust TACO: use built-in filter_datetime and filter_bbox, then flatten
- Lazy loading: use dask for windowed reads of large files
- Fast merging: vectorized operations for multi-region results
- DataLoader compatible: no internal parallelism
"""

from __future__ import annotations

import io
import re
from datetime import timedelta
from pathlib import Path
from typing import Literal

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import tacoreader
import torch
import torch.nn.functional as F
import xarray as xr
from torch.utils.data import Dataset

from .queries import Query

tacoreader.use("pandas")


# =============================================================================
# Constants
# =============================================================================

VAR_NAMES = {
    "glorys_ssh": "zos",
    "glorys_sst": "thetao",
    "glorys_sss": "so",
    "glorys_uo": "uo",
    "glorys_vo": "vo",
    "l4_ssh": "sla",
    "l4_sst": "analysed_sst",
    "l4_sss": "sos",
    "l4_wind": "eastward_wind",
    "l3_sst": "adjusted_sea_surface_temperature",
    "l3_sss_smos": "Sea_Surface_Salinity",
    "l3_ssh": "sla_filtered",
    "l3_swot": "ssha_filtered",
    "l3_sss_smos_asc": "Sea_Surface_Salinity",
    "l3_sss_smos_desc": "Sea_Surface_Salinity",
    "argo": "TEMP",
}

GRIDDED_SOURCES = {
    "glorys",
    "l4_ssh",
    "l4_sst",
    "l4_sss",
    "l4_wind",
    "l3_sst",
    "l3_ssh",
    "l3_swot",
    "l3_sss_smos_asc",
    "l3_sss_smos_desc",
}
POINT_SOURCES = {"argo"}

COL_VSI = "vsi_path"
TACOPAD_PREFIX = "__TACOPAD__"
_VSI_PATTERN = re.compile(r"/vsisubfile/(\d+)_(\d+),(.+)")


# =============================================================================
# Utilities
# =============================================================================


def build_file_index(
    taco_path: str, queries: list[Query], variables: list[str]
) -> list[pd.DataFrame]:
    """Pre-index files at init. Single SQL query, then split by date."""
    dataset = tacoreader.load(taco_path)

    # Find global time range across all queries
    all_dates = set()
    for q in queries:
        all_dates.add(pd.to_datetime(q.time_start).date())
        all_dates.add(pd.to_datetime(q.time_end).date())

    if not all_dates:
        dataset.close()
        return [pd.DataFrame() for _ in queries]

    time_start = min(all_dates)
    # Use exclusive upper bound so "2024-12-31T00:00:00+00:00" < "2025-01-01"
    time_end_excl = max(all_dates) + timedelta(days=1)

    # Build variable filter using tacoreader l2 column prefix
    var_conditions = []
    seen_glorys = False
    for var in variables:
        if var.startswith("glorys_"):
            if not seen_glorys:
                var_conditions.append("\"l2:data_source\" = 'glorys'")
                seen_glorys = True
        else:
            var_conditions.append(f"\"l2:data_source\" = '{var}'")
    var_filter = " OR ".join(var_conditions) if var_conditions else "1=1"

    # tacoreader flat-view SQL: tables are l0/l1/l2, columns prefixed "lN:col"
    sql = f"""
        SELECT
            "l0:stac:time_start"          AS time_start,
            "l2:internal:gdal_vsi"        AS vsi_path,
            "l2:res_deg_lat"              AS res_deg_lat,
            "l2:data_source"              AS data_source
        FROM l2
        WHERE "l2:id" NOT LIKE '{TACOPAD_PREFIX}%'
          AND "l0:stac:time_start" >= '{time_start}'
          AND "l0:stac:time_start" <  '{time_end_excl}'
          AND ({var_filter})
    """
    raw = dataset.sql(sql)
    dataset.close()

    # Normalise to pandas
    all_files: pd.DataFrame = raw.to_pandas() if hasattr(raw, "to_pandas") else raw
    all_files["time_start"] = pd.to_datetime(all_files["time_start"]).dt.date

    # Split by per-query date range (fast pandas filter)
    index = []
    for q in queries:
        q_start = pd.to_datetime(q.time_start).date()
        q_end = pd.to_datetime(q.time_end).date()
        mask = (all_files["time_start"] >= q_start) & (all_files["time_start"] <= q_end)
        index.append(all_files[mask].copy())

    return index


def parse_vsi_path(vsi_path: str) -> tuple[int, int, str] | None:
    """Extract (offset, size, filepath) from /vsisubfile/offset_size,path format."""
    if match := _VSI_PATTERN.match(vsi_path):
        return int(match[1]), int(match[2]), match[3]
    return None


def load_netcdf_var(
    vsi_path: str, var_name: str, bbox: tuple[float, float, float, float]
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """Load single variable from NetCDF, cropped to bbox. Handles local and remote files."""
    parsed = parse_vsi_path(vsi_path)

    if parsed:
        # /vsisubfile/offset_size,path — byte-range subfile
        offset, size, filepath = parsed
        if filepath.startswith("/vsicurl/"):
            url = filepath.replace("/vsicurl/", "")
            headers = {"Range": f"bytes={offset}-{offset + size - 1}"}
            resp = requests.get(url, headers=headers)
            resp.raise_for_status()
            ds = xr.open_dataset(io.BytesIO(resp.content), engine="h5netcdf")
        else:
            with open(filepath, "rb") as f:
                f.seek(offset)
                data = f.read(size)
            ds = xr.open_dataset(io.BytesIO(data), engine="h5netcdf")
    elif vsi_path.startswith("/vsicurl/"):
        # Direct /vsicurl/ URL — download entire file
        url = vsi_path.replace("/vsicurl/", "")
        resp = requests.get(url)
        resp.raise_for_status()
        ds = xr.open_dataset(io.BytesIO(resp.content), engine="h5netcdf")
    else:
        # Local file path
        ds = xr.open_dataset(vsi_path, engine="h5netcdf")

    if var_name not in ds:
        ds.close()
        return None

    lon_min, lon_max, lat_min, lat_max = bbox
    lons, lats = ds["lon"].values, ds["lat"].values

    lon_mask = (lons >= lon_min) & (lons <= lon_max)
    lat_mask = (lats >= lat_min) & (lats <= lat_max)

    if not lon_mask.any() or not lat_mask.any():
        ds.close()
        return None

    ds = ds.isel(lon=lon_mask, lat=lat_mask)
    data = ds[var_name].values
    lats_out = ds["lat"].values
    lons_out = ds["lon"].values
    ds.close()

    return data, lats_out, lons_out


def _interpolate_to_patch(data: np.ndarray, target_size: tuple[int, int]) -> np.ndarray:
    """Bilinear interpolation that preserves the NaN mask.

    NaN regions are excluded from contributing to interpolated values via
    a mask-weight normalization. Pixels where the interpolated mask weight
    falls below 0.01 are set back to NaN.
    """
    has_time = data.ndim == 3
    if not has_time:
        data = data[np.newaxis]  # -> (1, H, W)

    valid_mask = np.isfinite(data).astype(np.float32)
    data_filled = np.where(valid_mask.astype(bool), data, 0.0).astype(np.float32)

    h, w = target_size
    # (T, H, W) -> (T, 1, H, W) for F.interpolate
    t_data = torch.from_numpy(data_filled).unsqueeze(1)
    t_mask = torch.from_numpy(valid_mask).unsqueeze(1)

    t_data_r = F.interpolate(t_data, size=(h, w), mode="bilinear", align_corners=False)
    t_mask_r = F.interpolate(t_mask, size=(h, w), mode="bilinear", align_corners=False)

    data_r = t_data_r.numpy().squeeze(1)  # (T, H, W)
    mask_r = t_mask_r.numpy().squeeze(1)  # (T, H, W)

    # Normalize by mask weight; restore NaN where coverage is negligible
    result = np.where(mask_r > 0.01, data_r / np.maximum(mask_r, 1e-8), np.nan)

    if not has_time:
        result = result.squeeze(0)
    return result.astype(np.float32)


# =============================================================================
# Fast Grid Merging
# =============================================================================


class GridMerger:
    """Accumulates gridded data from multiple sources, computes mean."""

    __slots__ = (
        "bbox",
        "resolution",
        "target_lons",
        "target_lats",
        "shape",
        "_sum",
        "_count",
    )

    def __init__(self, bbox: tuple[float, float, float, float], resolution: float):
        self.bbox = bbox
        self.resolution = resolution

        lon_min, lon_max, lat_min, lat_max = bbox
        self.target_lons = np.arange(
            lon_min, lon_max + resolution / 2, resolution, dtype=np.float32
        )
        self.target_lats = np.arange(
            lat_min, lat_max + resolution / 2, resolution, dtype=np.float32
        )
        self.shape = (len(self.target_lats), len(self.target_lons))

        self._sum = np.zeros(self.shape, dtype=np.float64)
        self._count = np.zeros(self.shape, dtype=np.int32)

    def add(self, data: np.ndarray, src_lons: np.ndarray, src_lats: np.ndarray) -> None:
        if data.size == 0:
            return

        while data.ndim > 2:
            data = data.squeeze()

        if data.shape == self.shape:
            valid = np.isfinite(data)
            self._sum += np.where(valid, data, 0)
            self._count += valid.astype(np.int32)
            return

        # Coordinate mapping for mismatched grids
        lon_idx = np.clip(
            ((src_lons - self.bbox[0]) / self.resolution).astype(int),
            0,
            len(self.target_lons) - 1,
        )
        lat_idx = np.clip(
            ((src_lats - self.bbox[2]) / self.resolution).astype(int),
            0,
            len(self.target_lats) - 1,
        )

        # Vectorized scatter-add (replaces O(H×W) Python loop)
        li, lj = np.meshgrid(
            lat_idx[: data.shape[0]],
            lon_idx[: data.shape[1]],
            indexing="ij",
        )
        valid = np.isfinite(data)
        np.add.at(self._sum, (li[valid], lj[valid]), data[valid])
        np.add.at(self._count, (li[valid], lj[valid]), 1)

    def result(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        with np.errstate(invalid="ignore"):
            merged = np.where(self._count > 0, self._sum / self._count, np.nan)
        return merged.astype(np.float32), self.target_lats, self.target_lons


# =============================================================================
# Main Dataset Class
# =============================================================================


[docs] class OceanTACODataset(Dataset): """Query-based PyTorch Dataset for OceanTACO data. Pre-indexes files via SQL at init, making it safe for DataLoader with num_workers > 0. """ def __init__( self, taco_path: str, queries: list[Query], input_variables: list[str], target_variables: list[str], target_resolution: float | None = None, temporal_agg: Literal["first", "last", "mean", "stack"] = "mean", default_patch_size: tuple[int, int] = (128, 128), patch_sizes: dict[str, tuple[int, int]] | None = None, ): """Args: taco_path: Path to TACO dataset file queries: List of Query objects defining samples input_variables: Variables to load as inputs target_variables: Variables to load as targets target_resolution: Output grid resolution in degrees (None = use native) temporal_agg: How to aggregate multiple timestamps default_patch_size: Target (H, W) pixel size for all gridded variables. Default (128, 128). patch_sizes: Per-variable overrides, e.g. {"l4_ssh": (64, 64)}. Point sources are never resized. """ super().__init__() self.taco_path = taco_path self.queries = queries self.input_variables = list(input_variables) self.target_variables = list(target_variables) self.all_variables = list(set(input_variables) | set(target_variables)) self.target_resolution = target_resolution self.temporal_agg = temporal_agg self.default_patch_size = default_patch_size self.patch_sizes = patch_sizes or {} # Validate variables invalid = set(self.all_variables) - set(VAR_NAMES.keys()) if invalid: raise ValueError(f"Unknown variables: {invalid}. Valid: {list(VAR_NAMES)}") self._file_index = build_file_index(taco_path, queries, self.all_variables) def __len__(self) -> int: return len(self.queries) def __getitem__(self, idx: int) -> dict: query = self.queries[idx] file_df = self._file_index[idx] if file_df.empty: return self._empty_result(query) inputs = { v: self._load_variable(v, file_df, query.bbox) for v in self.input_variables } targets = { v: self._load_variable(v, file_df, query.bbox) for v in self.target_variables } return { "inputs": inputs, "targets": targets, "coords": self._extract_coords(inputs, targets, query.bbox), "metadata": { "bbox": query.bbox, "time_range": (query.time_start, query.time_end), "n_files": len(file_df), }, } def _load_variable( self, var: str, file_df: pd.DataFrame, bbox: tuple ) -> dict | None: if var.startswith("glorys_"): var_df = file_df[file_df["data_source"] == "glorys"] else: var_df = file_df[file_df["data_source"] == var] if var_df.empty: return None nc_var = VAR_NAMES[var] resolution = float(var_df["res_deg_lat"].iloc[0]) use_merger = len(var_df) > 1 merger = GridMerger(bbox, resolution) if use_merger else None data_list = [] lats_out, lons_out = None, None for _, row in var_df.iterrows(): vsi_path = row[COL_VSI] result = load_netcdf_var(vsi_path, nc_var, bbox) if not result: continue data, lats, lons = result if data.size == 0: continue if merger: merger.add(data, lons, lats) else: data_list.append(data) if lats_out is None: lats_out, lons_out = lats, lons if merger: data, lats_out, lons_out = merger.result() elif data_list: data = self._aggregate_temporal(data_list) else: return None if var == "l4_sst": data = data - 273.15 # Resize gridded vars to target patch size if var not in POINT_SOURCES and self.default_patch_size is not None: target_size = self.patch_sizes.get(var, self.default_patch_size) if data.shape[-2:] != target_size: data = _interpolate_to_patch(data, target_size) # Update coords to match new pixel grid lon_min, lon_max, lat_min, lat_max = bbox h, w = target_size lats_out = np.linspace(lat_min, lat_max, h, dtype=np.float32) lons_out = np.linspace(lon_min, lon_max, w, dtype=np.float32) # TODO normalization # Handle NaN data = np.nan_to_num(data, nan=0.0) if data.ndim > 2 and data.shape[0] == 1: data = data.squeeze(0) return { "data": torch.from_numpy(data.astype(np.float32)), "lats": torch.from_numpy(lats_out.astype(np.float32)) if lats_out is not None else None, "lons": torch.from_numpy(lons_out.astype(np.float32)) if lons_out is not None else None, } def _aggregate_temporal(self, data_list: list[np.ndarray]) -> np.ndarray: if len(data_list) == 1: return data_list[0] shapes = [d.shape for d in data_list] if len(set(shapes)) > 1: return data_list[0] stacked = np.stack(data_list, axis=0) if self.temporal_agg == "first": return stacked[0] elif self.temporal_agg == "last": return stacked[-1] elif self.temporal_agg == "mean": return np.nanmean(stacked, axis=0) elif self.temporal_agg == "stack": return stacked return stacked[0] def _extract_coords(self, inputs: dict, targets: dict, bbox: tuple) -> dict: for var_data in list(inputs.values()) + list(targets.values()): if var_data and var_data.get("lats") is not None: return {"lat": var_data["lats"], "lon": var_data["lons"]} # Fallback: derive from default_patch_size lon_min, lon_max, lat_min, lat_max = bbox h, w = self.default_patch_size return { "lat": torch.linspace(lat_min, lat_max, h), "lon": torch.linspace(lon_min, lon_max, w), } def _empty_result(self, query: Query) -> dict: return { "inputs": {v: None for v in self.input_variables}, "targets": {v: None for v in self.target_variables}, "coords": self._extract_coords({}, {}, query.bbox), "metadata": { "bbox": query.bbox, "time_range": (query.time_start, query.time_end), "n_files": 0, }, }
[docs] def visualize_sample( self, sample: dict, figsize: tuple[int, int] | None = None, save_path: str | Path | None = None, title: str = "", max_cols: int = 3, ): """Visualize all variables in a sample. Args: sample: Output from __getitem__ or _execute_query figsize: Figure size (width, height) save_path: Path to save figure (None = display) title: Optional title prefix max_cols: Maximum columns in subplot grid """ # Collect all variables to plot all_vars = {} for name, data in sample["inputs"].items(): if data is not None: all_vars[f"[Input] {name}"] = data for name, data in sample["targets"].items(): if data is not None: all_vars[f"[Target] {name}"] = data if not all_vars: print("No data to visualize!") return n_vars = len(all_vars) n_cols = min(max_cols, n_vars) n_rows = (n_vars + n_cols - 1) // n_cols if figsize is None: figsize = (6 * n_cols, 5 * n_rows) fig, axes = plt.subplots( n_rows, n_cols, figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=False, ) axes = axes.flatten() bbox = sample["metadata"].get("bbox") coords = sample["coords"] for ax, (var_label, var_data) in zip(axes, all_vars.items()): self._plot_variable(ax, var_label, var_data, coords, bbox) # Hide unused axes for i in range(n_vars, len(axes)): axes[i].axis("off") # Suptitle with metadata metadata = sample["metadata"] time_range = metadata.get("time_range", ("?", "?")) suptitle = f"{title}\n" if title else "" suptitle += f"Time: {time_range[0]} to {time_range[1]}" if bbox: suptitle += ( f" | BBox: [{bbox[0]:.1f}, {bbox[1]:.1f}, {bbox[2]:.1f}, {bbox[3]:.1f}]" ) suptitle += f" | Files: {metadata.get('n_files', '?')}" fig.suptitle(suptitle, fontsize=12, fontweight="bold") fig.tight_layout() if save_path: Path(save_path).parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f"Saved: {save_path}") else: plt.show() plt.close(fig)
def _plot_variable( self, ax, var_label: str, var_data: dict, coords: dict, bbox: tuple ): """Plot a single variable on an axis.""" data = var_data["data"].detach().cpu().numpy() # Use coordinates from var_data if available, else from sample coords if var_data.get("lats") is not None: lats = var_data["lats"].detach().cpu().numpy() lons = var_data["lons"].detach().cpu().numpy() else: lats = coords["lat"].detach().cpu().numpy() lons = coords["lon"].detach().cpu().numpy() if data.size == 0: ax.text( 0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes ) ax.set_title(var_label) return # Get colormap params cmap_params = _get_colormap_params(var_label) # Handle dark background for sparse data (L3 SSH/SWOT) var_lower = var_label.lower() use_dark_bg = "ssh" in var_lower or "swot" in var_lower or "sla" in var_lower if use_dark_bg: ax.set_facecolor("black") land_color = "#333333" grid_color = "gray" # Expose background for sparse L3 data (masked with 0.0 in loader) if "l3" in var_lower: data = data.copy() data[data == 0.0] = np.nan else: land_color = "lightgray" grid_color = "black" # Handle 3D data (time, lat, lon) - take last timestep if data.ndim == 3: data = data[-1] var_label += " (t=-1)" # Dynamically set vmin/vmax from data, fallback to defaults if not set finite_data = data[np.isfinite(data)] vmin = np.nanmin(finite_data) vmax = np.nanmax(finite_data) # Plot based on data shape if data.ndim == 2 and lats.ndim == 1: # Gridded data mappable = ax.pcolormesh( lons, lats, data, transform=ccrs.PlateCarree(), cmap=cmap_params["cmap"], vmin=vmin, vmax=vmax, rasterized=True, shading="gouraud" if "l3_swot" in var_lower else "auto", ) elif data.ndim == 1: # Point data mappable = ax.scatter( lons, lats, c=data, transform=ccrs.PlateCarree(), cmap=cmap_params["cmap"], vmin=vmin, vmax=vmax, s=10, alpha=0.8, ) else: ax.text( 0.5, 0.5, f"Shape: {data.shape}", ha="center", va="center", transform=ax.transAxes, ) ax.set_title(var_label) return plt.colorbar(mappable, ax=ax, label=cmap_params["label"], shrink=0.8) # Set extent if bbox: ax.set_extent([bbox[0], bbox[1], bbox[2], bbox[3]], crs=ccrs.PlateCarree()) # Add map features ax.coastlines(linewidth=0.5, color=grid_color) ax.add_feature(cfeature.LAND, facecolor=land_color, edgecolor="none") ax.gridlines(draw_labels=True, linewidth=0.3, alpha=0.5, color=grid_color) ax.set_title(f"{var_label}\nshape={data.shape}", fontsize=10)
# ============================================================================= # Visualization Helpers # ============================================================================= def _get_colormap_params(var_name: str) -> dict: """Get visualization parameters for a variable.""" var_lower = var_name.lower() if "ssh" in var_lower or "swot" in var_lower or "sla" in var_lower: return {"vmin": -0.6, "vmax": 0.6, "cmap": "RdBu_r", "label": "SSH (m)"} elif "sst" in var_lower or "temp" in var_lower: return {"vmin": 0, "vmax": 40, "cmap": "RdYlBu_r", "label": "SST (°C)"} elif "sss" in var_lower or "sal" in var_lower: return {"vmin": 32, "vmax": 38, "cmap": "viridis", "label": "SSS (PSU)"} elif "wind" in var_lower: return {"vmin": -15, "vmax": 15, "cmap": "coolwarm", "label": "Wind (m/s)"} elif "uo" in var_lower or "vo" in var_lower: return {"vmin": -2, "vmax": 2, "cmap": "coolwarm", "label": "Current (m/s)"} else: return {"vmin": None, "vmax": None, "cmap": "viridis", "label": "Value"} # ============================================================================= # Collate Function # =============================================================================
[docs] def collate_ocean_samples(batch: list[dict]) -> dict: """Collate function for DataLoader. Handles None values and variable-size tensors by padding. """ if not batch: return {} def stack_tensors(tensor_list: list[torch.Tensor | None]) -> torch.Tensor | None: tensors = [t for t in tensor_list if t is not None] if not tensors: return None # Pad to max shape ndim = tensors[0].ndim max_shape = [max(t.shape[i] for t in tensors) for i in range(ndim)] padded = [] for t in tensors: if list(t.shape) != max_shape: pad = [] for i in range(ndim - 1, -1, -1): pad.extend([0, max_shape[i] - t.shape[i]]) t = torch.nn.functional.pad(t, pad, value=0.0) padded.append(t) return torch.stack(padded, dim=0) # Collate inputs input_vars = list(batch[0]["inputs"].keys()) inputs = {} for var in input_vars: tensors = [ s["inputs"][var]["data"] if s["inputs"][var] else None for s in batch ] inputs[var] = stack_tensors(tensors) # Collate targets target_vars = list(batch[0]["targets"].keys()) targets = {} for var in target_vars: tensors = [ s["targets"][var]["data"] if s["targets"][var] else None for s in batch ] targets[var] = stack_tensors(tensors) # Coords from first sample coords = batch[0]["coords"] # Metadata metadata = { "bboxes": [s["metadata"]["bbox"] for s in batch], "time_ranges": [s["metadata"]["time_range"] for s in batch], } return { "inputs": inputs, "targets": targets, "coords": coords, "metadata": metadata, }