"""Raster data structure for map visualization with pcolormesh."""
import numpy as np
import pandas as pd
from typing import Optional, cast
from canopy.core.field import Field
from canopy.core.grid import get_grid_type
from canopy.core.grid.grid_lonlat import GridLonLat
from pandas.api.types import is_string_dtype
# [docs]  (Sphinx "view source" link residue; not part of the module code)
class Raster:
    """Raster data for map plotting with ``ax.pcolormesh``.

    Uses 1D center arrays (``x``, ``y``) with the same dimensions as the
    columns / rows of ``vmap``. Compatible with pcolormesh using
    ``shading='nearest'`` or ``'auto'``.

    Attributes
    ----------
    vmap : np.ndarray
        2D float array of shape ``(len(y), len(x))``. Cells that receive no
        data stay NaN. For string layers it holds integer category codes
        (stored as floats), with NaN for missing labels.
    keys : dict[int, str] or None
        Mapping from category code to label for string layers, ``None`` for
        every other layer.
    metadata : dict
        Name, units, description, time operation and grid axis information
        copied from the field and its grid.
    """

    @staticmethod
    def _get_coord_indices(coords: np.ndarray, coord_min: float, dcoord: float) -> np.ndarray:
        """Compute grid cell indices from coordinates (vectorized).

        Cell ``i`` covers ``[coord_min + (i - 0.5) * dcoord,
        coord_min + (i + 0.5) * dcoord)``, i.e. ``coord_min`` is the center
        of cell 0.

        Parameters
        ----------
        coords : np.ndarray
            Coordinates to locate.
        coord_min : float
            Center coordinate of the first cell.
        dcoord : float
            Cell width.

        Returns
        -------
        np.ndarray
            Integer (``np.intp``) cell index for every coordinate. Indices
            are negative for coordinates below the first cell.
        """
        offsets = (coords - (coord_min - 0.5 * dcoord)) / dcoord
        # Floor rather than truncate: truncation rounds toward zero and would
        # file coordinates just below the first cell under index 0.
        return np.floor(offsets).astype(np.intp)

    @staticmethod
    def _check_indices(indices: np.ndarray, size: int, axis: str) -> None:
        """Raise IndexError if any cell index lies outside ``[0, size)``.

        Negative indices would otherwise be accepted by NumPy indexing and
        silently wrap around to the opposite side of the raster.
        """
        if indices.size and (indices.min() < 0 or indices.max() >= size):
            raise IndexError(
                f"Some '{axis}' coordinates fall outside the grid "
                f"(valid cell indices are 0..{size - 1})."
            )

    @staticmethod
    def _encode_categories(labels) -> tuple:
        """Encode string labels as float category codes.

        Parameters
        ----------
        labels : pd.Series or array-like
            Labels to encode; may contain missing values.

        Returns
        -------
        tuple
            ``(keys, values)`` where ``keys`` maps each integer code to its
            label (codes follow order of first appearance) and ``values`` is
            a float array of codes with NaN where the label is missing.
        """
        # factorize leaves missing labels out of the categories and marks
        # them with code -1.
        codes, uniques = pd.factorize(labels)
        values = codes.astype(float)
        values[codes < 0] = np.nan
        return dict(enumerate(uniques)), values

    @staticmethod
    def _edges_from_centers(centers, fallback_step: float) -> np.ndarray:
        """Return cell edges for evenly spaced cell centers.

        The spacing is taken from the first two centers; ``fallback_step`` is
        used when there is only one center. Empty input yields an empty
        array.
        """
        centers = np.asarray(centers, dtype=float)
        if centers.size == 0:
            return np.empty(0, dtype=float)
        step = centers[1] - centers[0] if centers.size > 1 else fallback_step
        return np.concatenate([[centers[0] - step / 2], centers + step / 2])

    def __init__(self, field: Field, layer: str, timeop: str = "av") -> None:
        """Create a Raster from a Field.

        Parameters
        ----------
        field : Field
            The field object from which to create the Raster.
        layer : str
            The field layer to rasterize.
        timeop : str, optional
            Time reduction: 'av' or 'sum'. Default is 'av'. Only applied when
            the field has no time operation yet and the layer is not a
            string layer.

        Raises
        ------
        ValueError
            If the field's grid type is not 'lonlat'.
        KeyError
            If ``layer`` is not present in the field data.
        IndexError
            If any data coordinate falls outside the grid.
        """
        grid_type = get_grid_type(field.grid)
        if grid_type != "lonlat":
            raise ValueError("Raster currently supports only 'lonlat' grid type.")

        layer_dtype = field.data.dtypes.get(layer)
        is_categorical = layer_dtype is not None and is_string_dtype(layer_dtype)
        # Skip the time reduction for fields that already carry a time
        # operation and for string layers.
        if field.timeop is not None or is_categorical:
            field_to_use = field
        else:
            field_to_use = cast(Field, field.reduce_time(timeop))

        data = field_to_use.data
        grid = cast(GridLonLat, field_to_use.grid)
        lons = grid.lons
        lats = grid.lats

        ilon = self._get_coord_indices(
            data.index.get_level_values("lon").to_numpy(), grid.lon_min, grid.dlon
        )
        ilat = self._get_coord_indices(
            data.index.get_level_values("lat").to_numpy(), grid.lat_min, grid.dlat
        )
        self._check_indices(ilon, lons.size, "lon")
        self._check_indices(ilat, lats.size, "lat")

        self._x = lons
        self._y = lats
        # Grid spacing, kept so that edges are still well defined when an
        # axis holds a single cell.
        self._dx = grid.dlon
        self._dy = grid.dlat

        self.vmap = np.full((lats.size, lons.size), np.nan, dtype=float)
        if is_string_dtype(data.dtypes[layer]):
            keys, values = self._encode_categories(data[layer])
            self.keys: dict[int, str] | None = keys
            self.vmap[ilat, ilon] = values
        else:
            self.keys = None
            self.vmap[ilat, ilon] = data[layer].values

        self.metadata = {
            "name": field_to_use.metadata.get("name", "[no name]"),
            "units": field_to_use.metadata.get("units", "[no units]"),
            "description": field_to_use.metadata.get("description", "[no description]"),
            "timeop": field_to_use.timeop,
            "grid_type": grid_type,
            "grid_xaxis": grid.xaxis,
            "grid_yaxis": grid.yaxis,
        }

    @property
    def x(self) -> np.ndarray:
        """1D array of x (lon) cell centers. Same length as vmap columns."""
        return np.asarray(self._x)

    @property
    def y(self) -> np.ndarray:
        """1D array of y (lat) cell centers. Same length as vmap rows."""
        return np.asarray(self._y)

    @property
    def x_edges(self) -> np.ndarray:
        """1D array of x (lon) cell edges. For use with plotting packages that need edges."""
        # getattr keeps instances created without the stored spacing working;
        # they fall back to the historical default width of 1.0.
        return self._edges_from_centers(self._x, getattr(self, "_dx", 1.0))

    @property
    def y_edges(self) -> np.ndarray:
        """1D array of y (lat) cell edges. For use with plotting packages that need edges."""
        # Same fallback rule as x_edges, applied to the latitude spacing.
        return self._edges_from_centers(self._y, getattr(self, "_dy", 1.0))