import warnings
from typing import Optional, List
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import (
handle_figure_output, get_color_palette, make_dark_mode, get_field_metadata,
)
from canopy.visualization.line_plot.line_plot_helpers import (
apply_legend_style, var_case_matrix, get_n_hue_classes,
)
def make_latitudinal_plot(
    fields: cp.Field | List[cp.Field],
    output_file: Optional[str] = None,
    layers: Optional[List[str] | str] = None,
    yaxis_label: Optional[str] = None,
    field_labels: Optional[List[str]] = None,
    unit: Optional[str] = None,
    title: Optional[str] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
    move_legend: Optional[bool] = False,
    legend_style: str = 'default',
    max_labels_per_col: int = 15,
    dark_mode: bool = False,
    transparent: bool = False,
    x_label_rotation: float = 0,
    x_fig: float = 10,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a latitudinal plot showing variable values as a function of latitude.

    The plot displays mean values averaged over time (and longitude) at each
    latitude, with the variable on the x-axis and latitude on the y-axis.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Input data Field or list of Fields to display.
    output_file : str, optional
        File path for saving the plot.
    layers : List[str] or str, optional
        List of layer names to display. If None, all layers from the first field are used.
    yaxis_label : str, optional
        Label of the plotted variable (shown on the x-axis of this plot). If not
        provided, canopy will try to retrieve the name of the variable from the metadata.
    field_labels : List[str], optional
        List of labels for each field when multiple fields are provided. Required when
        multiple fields are used; labels must be unique.
    unit : str, optional
        Unit of the variable. If not provided, canopy will try to retrieve the unit of
        the variable from the metadata.
    title : str, optional
        Title of the plot.
    palette : str, optional
        Seaborn color palette to use for the line colors
        (https://seaborn.pydata.org/tutorial/color_palettes.html;
        recommended palettes are listed at https://colorbrewer2.org).
    custom_palette : str, optional
        Path of a custom color palette .txt file to use. Names should match label names.
    move_legend : bool, optional
        Move the legend outside of the plot. Default is False.
    legend_style : str, optional
        Style of the legend ('default', 'highlighted', 'end-of-line', 'hidden').
        If 'hidden', the legend will not be shown. Default is 'default'.
    max_labels_per_col : int, optional
        Maximum number of labels per column in the legend. Default is 15.
    dark_mode : bool, optional
        If True, apply dark mode styling. Default is False.
    transparent : bool, optional
        If True, make the figure transparent. Default is False.
    x_label_rotation : float, optional
        Rotation angle in degrees for the x-axis tick labels. Default is 0.
    x_fig : float, optional
        Width of the figure in inches. Default is 10.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot will be created in this subfigure instead of creating a new figure.
        This is used by multiple_figs() to combine multiple plots.
        Users can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html).
    return_fig : bool, optional
        If True, return a callable wrapper function instead of creating the plot immediately.
        This wrapper can be used with multiple_figs(). Default is False.
    **kwargs
        Additional keyword arguments are passed directly to `seaborn.lineplot`. This allows
        customization of line aesthetics such as `linewidth`, `linestyle`, `alpha`, etc.

    Returns
    -------
    The wrapper built by ``create_wrapper_from_locals`` when ``return_fig`` is True;
    otherwise whatever ``handle_figure_output`` returns (annotated as a Figure or None).

    Raises
    ------
    ValueError
        If a field has a reduced latitude axis, ``field_labels`` is missing, has the
        wrong length, or is not unique for multiple fields, a selected layer is missing
        from one of several fields, or no valid data remains to plot.
    """
    # Must stay the first statement: locals() here holds exactly the call
    # arguments, which the wrapper replays later.
    if return_fig:
        return create_wrapper_from_locals(make_latitudinal_plot, locals())

    # Normalise single inputs to lists.
    if not isinstance(fields, list):
        fields = [fields]
    if isinstance(layers, str):
        layers = [layers]

    # Fill in label, unit and layers from field metadata where not supplied.
    yaxis_label, unit, layers = get_field_metadata(fields, yaxis_label, unit, layers)

    _pre_checks(fields, layers, field_labels)
    lat_range = _get_lat_range(fields)
    df_long = _fields_to_long_format(fields, layers, field_labels, lat_range)

    fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig, y_fig=y_fig)

    n_fields = len(fields)
    n_layers = len(layers)
    hue_var, style_var = var_case_matrix(
        n_fields, n_layers, grid_not_reduced=False, reverse_hue_style=False
    )
    # Number of palette colors is driven by whatever the hue variable encodes.
    n_classes = get_n_hue_classes(hue_var, [], n_fields, n_layers)
    colors, _ = get_color_palette(n_classes, palette=palette, custom_palette=custom_palette)

    has_groups = hue_var is not None or style_var is not None
    # Plot latitude on x so seaborn sorts/groups along latitude; the axes are
    # transposed afterwards to put latitude on the vertical axis.
    plot_kwargs = {
        "data": df_long,
        "x": "latitude",
        "y": "value",
        "ax": ax,
        "hue": hue_var,
        "style": style_var,
        "legend": bool(legend_style) if has_groups else False,
    }
    # seaborn ignores (and warns about) a palette when no hue is assigned.
    if hue_var is not None:
        plot_kwargs["palette"] = colors
    # User kwargs win over the defaults above. 'subfig' is a named parameter,
    # so it can never be inside **kwargs.
    plot_kwargs.update(kwargs)

    sns.lineplot(**plot_kwargs)
    _swap_plot_axes(ax)

    if has_groups:
        apply_legend_style(
            ax, legend_style, max_labels_per_col, move_legend,
            hue_var=hue_var, style_var=style_var,
        )

    ax.set_xlabel(_format_value_label(yaxis_label, unit), fontsize=16)
    ax.set_ylabel("Latitude", fontsize=16)
    ax.tick_params(labelsize=14)
    if x_label_rotation:
        ax.tick_params(axis='x', labelrotation=x_label_rotation)
    if title:
        ax.set_title(title, fontsize=18, pad=20)

    # Remove the top and right spines.
    sns.despine(ax=ax)

    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)

    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)


def _swap_plot_axes(ax) -> None:
    """Transpose an axes drawn as (latitude, value) into (value, latitude).

    The x/y limits are exchanged, and the data of every Line2D on the axes is
    swapped. Collections (e.g. shaded error bands) are left as drawn.
    Presumably no band is visible: after averaging, each plotted group normally
    has a single value per latitude. Confirm if new aggregation paths are added.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim(ylim)
    ax.set_ylim(xlim)
    for line in ax.get_lines():
        line.set_data(line.get_ydata(), line.get_xdata())


def _format_value_label(yaxis_label: Optional[str], unit: Optional[str]) -> str:
    """Build the value-axis label, appending the unit unless it is empty or '[no units]'.

    Falls back to "Value" when no variable name is available, so the label never
    reads "None (in ...)".
    """
    name = yaxis_label or "Value"
    if unit and unit != "[no units]":
        return f"{name} (in {unit})"
    return name
def _pre_checks(
    fields: List[cp.Field],
    layers: List[str] | str,
    field_labels: Optional[List[str]],
) -> None:
    """
    Validate fields, layers, and field_labels before processing.

    Parameters
    ----------
    fields : List[cp.Field]
        Fields to plot.
    layers : List[str] or str
        Layer name(s) selected for plotting.
    field_labels : List[str], optional
        One label per field; required (and must be unique) when there is more
        than one field.

    Raises
    ------
    ValueError
        If any field has a reduced latitude axis. With multiple fields, also
        raised if field_labels is missing, has the wrong length, or contains
        duplicates, or if a selected layer is missing from one of the fields.
    """
    for i, field in enumerate(fields):
        if field.grid.is_reduced('lat'):
            raise ValueError(
                f"Field {i} has reduced latitude. Latitudinal plot requires the latitude axis to be unreduced."
            )

    # The remaining checks only matter when several fields are compared.
    if len(fields) <= 1:
        return

    if field_labels is None:
        raise ValueError("field_labels must be defined when there are more than one field.")
    if len(field_labels) != len(fields):
        raise ValueError("field_labels should be of the same size as the number of fields.")
    # Labels are the grouping key for the plotted lines. Duplicates would merge
    # different fields into one group with several values per latitude.
    if len(set(field_labels)) != len(field_labels):
        raise ValueError(
            "field_labels must be unique; duplicate labels would merge different fields into one line."
        )

    layers_set = set(layers) if isinstance(layers, list) else {layers}
    for i, field in enumerate(fields):
        missing_layers = layers_set - set(field.layers)
        if missing_layers:
            raise ValueError(
                f'Field {i} is missing the following selected layer(s): {missing_layers}. '
                'Tip: use field.rename_layers() method or select layers that exist in all fields.'
            )
def _get_lat_range(fields: List[cp.Field]) -> List[float]:
    """
    Return the union latitude range ``[min lat_min, max lat_max]`` over the fields.

    Only fields whose grid exposes both ``lat_min`` and ``lat_max`` and whose
    latitude axis is not reduced contribute. If none do, ``[-90.0, 90.0]`` is
    returned.
    """
    bounds = [
        (field.grid.lat_min, field.grid.lat_max)
        for field in fields
        if hasattr(field.grid, 'lat_min')
        and hasattr(field.grid, 'lat_max')
        and not field.grid.is_reduced('lat')
    ]
    if not bounds:
        return [-90.0, 90.0]
    lows, highs = zip(*bounds)
    return [min(lows), max(highs)]
def _fields_to_long_format(
    fields: List[cp.Field],
    layers: List[str] | str,
    field_labels: Optional[List[str]],
    lat_range: List[float],
) -> pd.DataFrame:
    """
    Build the long-format table plotted by the latitudinal plot.

    Each field is averaged over longitude (unless its longitude axis is already
    reduced). For each requested layer, a MultiIndex series is averaged over
    every index level other than latitude (e.g. time). Rows whose latitude falls
    outside ``lat_range`` or whose latitude/value is NaN are dropped.

    Parameters
    ----------
    fields : List[cp.Field]
        Fields to convert.
    layers : List[str] or str
        Layer name(s) to extract. Layers missing from a field are skipped with a warning.
    field_labels : List[str], optional
        Label per field. Fields without a label get ``"Field <n>"`` (1-based).
    lat_range : List[float]
        Inclusive ``[low, high]`` latitude bounds.

    Returns
    -------
    pd.DataFrame
        Columns ``latitude`` (float), ``value`` (float), ``field_label`` and ``layer``.

    Raises
    ------
    ValueError
        If no latitude level can be identified in a MultiIndex, or if no valid
        rows remain.
    """
    layer_list = layers if isinstance(layers, list) else [layers]
    labels = field_labels or []
    lat_low, lat_high = lat_range[0], lat_range[1]
    frames = []

    for i, field in enumerate(fields):
        label = labels[i] if i < len(labels) else f"Field {i+1}"

        field_lon_red = field
        if not field.grid.is_reduced('lon'):
            reduced = field.reduce_grid('av', axis='lon')
            # A None result presumably means the reduction happened in place
            # (TODO confirm); fall back to the field itself as before.
            if reduced is not None:
                field_lon_red = reduced

        for layer in layer_list:
            if layer not in field_lon_red.layers:
                # stacklevel=3 attributes the warning to the plot function's caller.
                warnings.warn(f"Layer '{layer}' not found in field {i}, skipping.", UserWarning, stacklevel=3)
                continue

            series = field_lon_red.data[layer]
            if isinstance(series.index, pd.MultiIndex):
                # Average over every non-latitude level (e.g. time).
                series = series.groupby(level=_find_lat_level(series.index)).mean()

            latitudes = np.asarray(pd.to_numeric(series.index.values, errors='coerce'), dtype=float)
            # na_value keeps nullable/object dtypes (None, pd.NA) from raising.
            values = series.to_numpy(dtype=float, na_value=np.nan)

            keep = (
                (latitudes >= lat_low)
                & (latitudes <= lat_high)
                & ~np.isnan(latitudes)
                & ~np.isnan(values)
            )
            if keep.any():
                frames.append(pd.DataFrame({
                    "latitude": latitudes[keep],
                    "value": values[keep],
                    "field_label": label,
                    "layer": layer,
                }))

    if not frames:
        raise ValueError("No valid data found for plotting.")
    return pd.concat(frames, ignore_index=True)


def _find_lat_level(index: pd.MultiIndex) -> int | str:
    """Return the MultiIndex level to treat as latitude.

    Returns ``'lat'`` if such a level exists. Otherwise returns the position of
    the first level not named ``'time'``. Positions are used so that unnamed
    levels can still be grouped.

    Raises
    ------
    ValueError
        If every level is named ``'time'``.
    """
    if 'lat' in index.names:
        return 'lat'
    for position, name in enumerate(index.names):
        if name != 'time':
            return position
    raise ValueError("Could not find latitude level in index")