# Source code for canopy.visualization.line_plot.latitudinal_plot

import warnings
from typing import Optional, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import (
    handle_figure_output, get_color_palette, make_dark_mode, get_field_metadata,
)
from canopy.visualization.line_plot.line_plot_helpers import (
    apply_legend_style, var_case_matrix, get_n_hue_classes,
)

def make_latitudinal_plot(
    fields: cp.Field | List[cp.Field],
    output_file: Optional[str] = None,
    layers: Optional[List[str] | str] = None,
    yaxis_label: Optional[str] = None,
    field_labels: Optional[List[str]] = None,
    unit: Optional[str] = None,
    title: Optional[str] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
    move_legend: Optional[bool] = False,
    legend_style: str = 'default',
    max_labels_per_col: int = 15,
    dark_mode: bool = False,
    transparent: bool = False,
    x_label_rotation: float = 0,
    x_fig: float = 10,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a latitudinal plot showing variable values as a function of latitude.

    The plot displays mean values averaged over time (and over longitude when the
    longitude axis is not already reduced) at each latitude. Values are drawn on
    the x-axis and latitude on the y-axis.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Input data Field or list of Fields to display.
    output_file : str, optional
        File path for saving the plot.
    layers : List[str] or str, optional
        List of layer names to display. If None, all layers from the first field are used.
    yaxis_label : str, optional
        Name of the plotted variable. Because values are drawn horizontally, it is
        used as the x-axis label. If not provided canopy will try to retrieve the
        name of the variable in the metadata.
    field_labels : List[str], optional
        List of labels for each field when multiple fields are provided.
        Required when multiple fields are used.
    unit : str, optional
        Unit of the variable. If not provided canopy will try to retrieve the unit
        of the variable in the metadata.
    title : str, optional
        Title of the plot.
    palette : str, optional
        Seaborn color palette to use for the line colors
        (https://seaborn.pydata.org/tutorial/color_palettes.html, recommended
        palettes are in https://colorbrewer2.org).
    custom_palette : str, optional
        Path of custom color palette .txt file to use. Names should match label names.
    move_legend : bool, optional
        Move the legend outside of plot. Default is False.
    legend_style : str, optional
        Style of the legend ('default', 'highlighted', 'end-of-line', 'hidden').
        If 'hidden', the legend will not be shown. Default is 'default'.
    max_labels_per_col : int, optional
        Maximum number of labels per column in the legend. Default is 15.
    dark_mode : bool, optional
        If True, apply dark mode styling. Default is False.
    transparent : bool, optional
        If True, make the figure transparent. Default is False.
    x_label_rotation : float, optional
        Rotation angle in degrees for the x-axis tick labels. Default is 0.
    x_fig : float, optional
        Width of the figure in inches. Default is 10.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot will be created in this subfigure instead of creating
        a new figure. This is used by multiple_figs() to combine multiple plots.
        User can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html).
    return_fig : bool, optional
        If True, return a callable wrapper function instead of creating the plot
        immediately. This wrapper can be used with multiple_figs(). Default is False.
    **kwargs
        Additional keyword arguments are passed directly to `seaborn.lineplot`.
        This allows customization of line aesthetics such as `linewidth`,
        `linestyle`, `alpha`, etc.

    Returns
    -------
    Optional[plt.Figure]
        Whatever ``handle_figure_output`` returns for the finished figure; when
        ``return_fig`` is True, the wrapper built by ``create_wrapper_from_locals``.

    Raises
    ------
    ValueError
        From input validation (reduced latitude axis, missing/mismatched
        ``field_labels``, missing layers) or when no plottable data remain.
    """
    # Must stay the first statement: locals() here holds only the call arguments,
    # which is what the wrapper re-uses.
    if return_fig:
        return create_wrapper_from_locals(make_latitudinal_plot, locals())

    # Normalise inputs to lists
    if not isinstance(fields, list):
        fields = [fields]
    if isinstance(layers, str):
        layers = [layers]

    # Retrieve label / unit / layers from metadata when not given explicitly
    yaxis_label, unit, layers = get_field_metadata(fields, yaxis_label, unit, layers)

    _pre_checks(fields, layers, field_labels)
    lat_range = _get_lat_range(fields)
    df_long = _fields_to_long_format(fields, layers, field_labels, lat_range)

    fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig, y_fig=y_fig)

    n_fields = len(fields)
    n_layers = len(layers)
    hue_var, style_var = var_case_matrix(
        n_fields, n_layers, grid_not_reduced=False, reverse_hue_style=False
    )

    # Number of colour classes depends on which variable drives the hue
    n_classes = get_n_hue_classes(hue_var, [], n_fields, n_layers)
    # NOTE(review): the second returned value (presumably a name -> colour mapping)
    # is not used, so colours are assigned in hue order; confirm that
    # custom_palette names are matched to labels as the docstring promises.
    colors, _palette_dict = get_color_palette(n_classes, palette=palette, custom_palette=custom_palette)

    # seaborn aggregates/sorts along x, so plot latitude on x first and transpose
    # the drawn artists afterwards.
    plot_kwargs = {
        "data": df_long,
        "x": "latitude",
        "y": "value",
        "ax": ax,
    }
    if hue_var is None and style_var is None:
        plot_kwargs.update(hue=None, style=None, legend=False)
    else:
        plot_kwargs.update(
            hue=hue_var,
            style=style_var,
            # NOTE(review): bool() of any non-empty string is True, so 'hidden'
            # still requests a legend here; apply_legend_style presumably hides it.
            legend=bool(legend_style),
            palette=colors,
        )

    # User kwargs override the defaults above. 'subfig' is an explicit parameter
    # and so cannot be in kwargs; the filter is only a defensive guard.
    plot_kwargs.update({k: v for k, v in kwargs.items() if k != 'subfig'})

    sns.lineplot(**plot_kwargs)

    # Desired orientation: value on x, latitude on y
    _swap_plot_axes(ax)

    # Legend styling only makes sense when seaborn was asked for hue/style groups
    if hue_var is not None or style_var is not None:
        apply_legend_style(
            ax, legend_style, max_labels_per_col, move_legend,
            hue_var=hue_var, style_var=style_var,
        )

    # Axis labels: the variable name goes on x because values are horizontal
    xlabel = f"{yaxis_label} (in {unit})" if unit and unit != "[no units]" else (yaxis_label or "Value")
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel("Latitude", fontsize=16)
    ax.tick_params(labelsize=14)
    if x_label_rotation:
        ax.tick_params(axis='x', labelrotation=x_label_rotation)
    if title:
        ax.set_title(title, fontsize=18, pad=20)

    # Remove the top and right spines
    sns.despine(ax=ax)

    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)

    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)


def _swap_plot_axes(ax) -> None:
    """
    Exchange the x and y roles of everything drawn on ``ax``.

    The axis limits, every Line2D's data and the path vertices of every
    collection are transposed. Collections include error bands from
    ``fill_between`` and error-bar LineCollections. Transposing them keeps
    everything in the same orientation as the lines.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim(ylim)
    ax.set_ylim(xlim)

    for line in ax.get_lines():
        line.set_data(line.get_ydata(), line.get_xdata())

    for collection in ax.collections:
        for path in collection.get_paths():
            # Shared/readonly paths (e.g. marker templates) cannot be modified
            if not path.readonly:
                path.vertices = path.vertices[:, ::-1].copy()
        collection.stale = True
def _pre_checks(
    fields: "List[cp.Field]",
    layers: "List[str] | str",
    field_labels: Optional[List[str]],
) -> None:
    """
    Validate fields, layers, and field_labels before processing.

    Raises
    ------
    ValueError
        If any field has a reduced latitude axis; if several fields are given
        without ``field_labels`` or with ``field_labels`` of a different length;
        or, when several fields are given, if a selected layer is missing from
        one of them.
    """
    for i, field in enumerate(fields):
        if field.grid.is_reduced('lat'):
            raise ValueError(
                f"Field {i} has reduced latitude. Latitudinal plot requires the latitude axis to be unreduced."
            )

    if len(fields) <= 1:
        # A single field needs no labels; a missing layer is skipped with a
        # warning later in _fields_to_long_format.
        return

    if field_labels is None:
        raise ValueError("field_labels must be defined when there are more than one field.")
    if len(field_labels) != len(fields):
        raise ValueError("field_labels should be of the same size as the number of fields.")

    selected_layers = set(layers) if isinstance(layers, list) else {layers}
    for i, field in enumerate(fields):
        missing_layers = selected_layers - set(field.layers)
        if missing_layers:
            raise ValueError(
                f'Field {i} is missing the following selected layer(s): {missing_layers}. '
                'Tip: use field.rename_layers() method or select layers that exist in all fields.'
            )


def _get_lat_range(fields: "List[cp.Field]") -> List[float]:
    """
    Return ``[lat_min, lat_max]`` spanning the latitude extents of all fields.

    Only grids that expose non-None ``lat_min`` / ``lat_max`` and an unreduced
    latitude axis contribute. Falls back to ``[-90.0, 90.0]`` when none do.
    """
    lat_mins = []
    lat_maxs = []
    for field in fields:
        grid = field.grid
        lat_min = getattr(grid, 'lat_min', None)
        lat_max = getattr(grid, 'lat_max', None)
        # A None bound would otherwise leak into the range and break the
        # numeric comparisons in _fields_to_long_format.
        if lat_min is None or lat_max is None or grid.is_reduced('lat'):
            continue
        lat_mins.append(lat_min)
        lat_maxs.append(lat_max)

    if not lat_mins:
        return [-90.0, 90.0]
    return [min(lat_mins), max(lat_maxs)]


def _as_float_array(values) -> np.ndarray:
    """Coerce array-like values to a float64 ndarray; missing or non-numeric
    entries (NaN, None, ``pd.NA``, unparsable strings) become NaN."""
    return pd.to_numeric(pd.Series(values), errors="coerce").to_numpy(dtype=float, na_value=np.nan)


def _latitude_means(data: pd.Series):
    """
    Return ``(latitudes, values)`` for one layer's data.

    With a MultiIndex, values are averaged over all other levels, grouping by the
    level named 'lat' or, failing that, the first level not named 'time'. With a
    flat index, the index is used as latitude as-is. ``data`` is assumed to be a
    Series — TODO confirm against Field.data.

    Raises
    ------
    ValueError
        If a MultiIndex has no usable latitude level.
    """
    index = data.index
    if not isinstance(index, pd.MultiIndex):
        return index.values, data.values

    names = list(index.names)
    if 'lat' in names:
        lat_level = 'lat'
    else:
        # Use the level position so unnamed (None) levels also work
        lat_level = next((pos for pos, name in enumerate(names) if name != 'time'), None)
        if lat_level is None:
            raise ValueError("Could not find latitude level in index")

    grouped = data.groupby(level=lat_level).mean()
    return grouped.index.values, grouped.values


def _fields_to_long_format(
    fields: "List[cp.Field]",
    layers: "List[str] | str",
    field_labels: Optional[List[str]],
    lat_range: List[float],
) -> pd.DataFrame:
    """
    Process fields: reduce along longitude, calculate mean over time per latitude,
    and return a long-format DataFrame (latitude, value, field_label, layer).

    Rows whose latitude lies outside ``lat_range`` (inclusive) or whose latitude
    or value is missing/non-numeric are dropped. Fields without a label are
    named "Field {n}" (1-based). Layers absent from a field are skipped with a
    UserWarning.

    Raises
    ------
    ValueError
        If no valid rows remain.
    """
    layer_list = layers if isinstance(layers, list) else [layers]
    labels = field_labels or []
    lat_lo, lat_hi = lat_range[0], lat_range[1]

    frames = []
    for i, field in enumerate(fields):
        label = labels[i] if i < len(labels) else f"Field {i+1}"

        if field.grid.is_reduced('lon'):
            lon_reduced = field
        else:
            reduced = field.reduce_grid('av', axis='lon')
            # NOTE(review): the fallback to `field` when reduce_grid returns None
            # suggests it may reduce in place; confirm whether the caller's Field
            # is mutated here.
            lon_reduced = field if reduced is None else reduced

        for layer in layer_list:
            if layer not in lon_reduced.layers:
                warnings.warn(f"Layer '{layer}' not found in field {i}, skipping.", UserWarning)
                continue

            latitudes, values = _latitude_means(lon_reduced.data[layer])
            lat_arr = _as_float_array(latitudes)
            val_arr = _as_float_array(values)

            # NaN latitudes fail both comparisons, so they are dropped as well
            keep = (lat_arr >= lat_lo) & (lat_arr <= lat_hi) & ~np.isnan(val_arr)
            n_rows = int(keep.sum())
            if n_rows == 0:
                # Skipping empty frames keeps the concatenated columns float64
                continue

            frames.append(pd.DataFrame({
                "latitude": lat_arr[keep],
                "value": val_arr[keep],
                "field_label": [label] * n_rows,
                "layer": [layer] * n_rows,
            }))

    if not frames:
        raise ValueError("No valid data found for plotting.")
    return pd.concat(frames, ignore_index=True)