Source code for canopy.visualization.taylor_diagram.taylor_diagram

import warnings
from typing import Optional, List, Tuple, Any

import geocat.viz as gv
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import get_color_palette, make_dark_mode, handle_figure_output

# Marker symbols used to tell symbol classes (layers or fields) apart.
# `_get_marker_for_class` indexes this list modulo its length, so markers
# repeat once there are more classes than entries here.
MARKERS = ['o', 's', '^', 'v', 'D', 'p', 'h', '*', '<', '>']

def make_taylor_diagram(
    fields: cp.Field | List[cp.Field],
    obs: cp.Field,
    output_file: Optional[str] = None,
    gridop: Optional[str] = None,
    title: Optional[str] = None,
    layers: Optional[List[str] | str] = None,
    field_labels: Optional[List[str]] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
    dark_mode: bool = False,
    transparent: bool = False,
    marker_size: float = 100,
    marker: str = 'o',
    fontsize: int = 16,
    x_fig: float = 12,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a Taylor diagram (https://pcmdi.llnl.gov/staff/taylor/CV/Taylor_diagram_primer.pdf)
    comparing model data against observations.

    Note: This function requires all input fields and observations to have
    grid_type='sites'.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Model data Field(s) to compare against observations. If a list is
        provided, multiple fields will be plotted with one point per field.
    obs : cp.Field
        Observation/reference data Field (typically the "truth").
    output_file : str, optional
        File path for saving the plot.
    gridop : str, optional
        If provided, the grid reduction operation. Either None, 'sum' or 'av'.
        Default is None.
    title : str, optional
        Title of the plot.
    layers : List[str] or str, optional
        Layer names to include. Defaults to all layers in the first field.
    field_labels : List[str], optional
        Labels for each field when multiple fields are provided. Required
        when `fields` contains more than one field.
    palette : str, optional
        Seaborn color palette to use
        (https://seaborn.pydata.org/tutorial/color_palettes.html).
    custom_palette : str, optional
        Path of custom color palette .txt file to use. Names should match
        site labels.
    dark_mode : bool, optional
        Whether to apply dark mode styling to the plot.
    transparent : bool, optional
        If True, makes the background of the figure transparent.
    marker_size : float, optional
        Size of the markers. Default is 100.
    marker : str, optional
        Marker shape. Default is 'o' (circle). It is used when marker shape
        is not needed to tell fields or layers apart.
    fontsize : int, optional
        Font size for labels. Default is 16.
    x_fig : float, optional
        Width of the figure in inches. Default is 12.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot will be created in this subfigure instead of
        creating a new figure. This is used by multiple_figs() to combine
        multiple plots. User can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html)
    return_fig : bool, optional
        If True, return a callable wrapper function instead of creating the
        plot immediately. This wrapper can be used with multiple_figs().
        Default is False.
    **kwargs
        Additional keyword arguments are passed directly to
        `taylor.add_model_set`. This allows customization of plot features
        such as `linestyle`, `linewidth`, etc.

    Returns
    -------
    Optional[plt.Figure]
        The result of `handle_figure_output` for the created figure, or the
        wrapper built by `create_wrapper_from_locals` when `return_fig` is
        True.

    Raises
    ------
    ValueError
        If no field is given, if the pre-checks on fields/obs/labels fail,
        if no point with usable statistics remains, or if sites, fields and
        layers all vary at once (only two can be encoded as color/marker).
    """
    # Must run before any new local is created: the wrapper is built from
    # locals(), which at this point holds only the call arguments.
    if return_fig:
        return create_wrapper_from_locals(make_taylor_diagram, locals())

    # Force fields and layers to be lists
    if isinstance(fields, cp.Field):
        fields = [fields]
    if not fields:
        # Fail with a clear message instead of an IndexError on fields[0]
        raise ValueError("At least one field must be provided")
    if isinstance(layers, str):
        layers = [layers]
    layers = layers or list(fields[0].layers)

    n_fields = len(fields)
    n_layers = len(layers)

    _pre_checks(fields, obs, field_labels, layers, n_fields)

    # Apply grid reduction to obs if needed
    obs_field = obs.reduce_grid(gridop) if gridop else obs

    # One dict per plotted point: statistics plus site/layer/field metadata
    points_data = _compute_taylor_points(
        fields, obs_field, gridop, layers, field_labels, n_fields, n_layers
    )
    if not points_data:
        raise ValueError("No valid data found with sufficient data for Taylor diagram")

    # Sites are only known once the points exist, so the three-way check
    # cannot be done any earlier than this.
    n_sites = len({p["site"] for p in points_data})
    if n_sites > 1 and n_fields > 1 and n_layers > 1:
        raise ValueError(
            "Cannot differentiate 3 categories (sites, fields, layers). Use at most 2: "
            "e.g. reduce to one site, one field, or one layer."
        )

    color_dim, symbol_dim = _taylorvar_case_matrix(n_sites, n_fields, n_layers)

    # Assign colors and markers per point
    colors, markers = _assign_colors_and_markers(
        points_data, color_dim, symbol_dim, palette, custom_palette, marker
    )

    stddev_norm = [p["std_norm"] for p in points_data]
    corrcoef = [p["corr"] for p in points_data]
    label_names = [p["label"] for p in points_data]

    # Create the Taylor diagram figure
    fig = _plot_taylor_diagram(
        stddev_norm, corrcoef, label_names, colors, markers,
        x_fig, y_fig, marker_size, fontsize, title, dark_mode,
        subfig=subfig, **kwargs
    )

    # Handle output and return
    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)

def _pre_checks(
    fields: List[cp.Field],
    obs: cp.Field,
    field_labels: Optional[List[str]],
    layers: List[str],
    n_fields: int,
) -> None:
    """Validate fields, obs, and options before processing.

    Raises
    ------
    ValueError
        If a field or obs does not have grid_type 'sites', if a requested
        layer is missing from obs or from any field, or if `field_labels`
        is missing / has the wrong length when several fields are given.
    """
    for i, field in enumerate(fields):
        if field.grid.grid_type != "sites":
            raise ValueError(
                f"All fields must have grid_type='sites'. "
                f"Found: field[{i}] has grid_type='{field.grid.grid_type}'"
            )
    if obs.grid.grid_type != "sites":
        raise ValueError(
            f"Obs must have grid_type='sites'. Found: obs has grid_type='{obs.grid.grid_type}'"
        )

    for layer in layers:
        if layer not in obs.layers:
            raise ValueError(
                f"Obs must contain layer '{layer}'. Obs has layers: {list(obs.layers)}"
            )
        for i, field in enumerate(fields):
            if layer not in field.layers:
                raise ValueError(
                    f"Field[{i}] must contain layer '{layer}'. Found: {list(field.layers)}"
                )

    if n_fields > 1:
        if field_labels is None:
            raise ValueError("field_labels must be provided when multiple fields are provided")
        if len(field_labels) != n_fields:
            raise ValueError(
                f"field_labels length ({len(field_labels)}) must match number of fields ({n_fields})"
            )


# (sites vary, fields vary, layers vary) -> (color_dim, symbol_dim).
# Combinations not listed (nothing varies, or all three vary) map to
# (None, None); the all-three case is rejected by the caller.
_COLOR_SYMBOL_DIMS = {
    (True, True, False): ("site", "field"),
    (True, False, True): ("site", "layer"),
    (True, False, False): ("site", None),
    (False, True, True): ("field", "layer"),
    (False, True, False): ("field", None),
    (False, False, True): ("layer", None),
}

# Key of the per-point dict that holds the class value for each dimension.
_POINT_KEY_BY_DIM = {"site": "site", "field": "field_label", "layer": "layer"}


def _taylorvar_case_matrix(
    n_sites: int,
    n_fields: int,
    n_layers: int,
) -> Tuple[Optional[str], Optional[str]]:
    """Return (color_dim, symbol_dim) from case matrix. At most 2 dimensions can vary.

    A dimension "varies" when its count is greater than one. The first
    varying dimension (in the order site, field, layer) is encoded by color
    and the second, if any, by marker symbol.
    """
    varying = (n_sites > 1, n_fields > 1, n_layers > 1)
    return _COLOR_SYMBOL_DIMS.get(varying, (None, None))


def _get_marker_for_class(
    symbol_class: Any,
    symbol_classes: List[Any],
) -> str:
    """Map symbol_class to a marker, cycling if more classes than markers.

    A class that is not in `symbol_classes` gets the first marker.
    """
    try:
        idx = symbol_classes.index(symbol_class)
    except ValueError:
        idx = 0
    return MARKERS[idx % len(MARKERS)]


def _compute_taylor_points(
    fields: List[cp.Field],
    obs_field: cp.Field,
    gridop: Optional[str],
    layers: List[str],
    field_labels: Optional[List[str]],
    n_fields: int,
    n_layers: int,
) -> List[dict]:
    """Compute (std_norm, corr, site, layer, field_idx, label) for each valid point.

    Every column of the observation table that also exists in a field's
    table yields one candidate point. Columns whose statistics cannot be
    computed are skipped with a warning; columns absent from the field are
    skipped silently.

    Returns
    -------
    List[dict]
        One dict per point with keys "std_norm", "corr", "site", "layer",
        "field_idx", "field_label" and "label".
    """
    obs_df = cp.make_lines(obs_field, flatten_columns=True, layers=layers)
    obs_columns = list(obs_df.columns)

    points = []
    for field_idx, field in enumerate(fields):
        field_to_use = field.reduce_grid(gridop) if gridop else field
        df_field = cp.make_lines(field_to_use, flatten_columns=True, layers=layers)

        # The label depends only on the field, so resolve it once per field.
        if field_labels and field_idx < len(field_labels):
            field_label = field_labels[field_idx]
        else:
            field_label = f"Field {field_idx}"

        for col in obs_columns:
            if col not in df_field.columns:
                continue

            std_norm, corr = _compute_stats(df_field[col], obs_df[col])
            if std_norm is None:
                label = _parse_column_label(col, n_layers)
                warnings.warn(
                    f"Warning: Skipping {label}: insufficient data or zero standard deviation"
                )
                continue

            site, layer = _parse_column_to_site_layer(col, n_layers)
            points.append({
                "std_norm": std_norm,
                "corr": corr,
                "site": site,
                "layer": layer,
                "field_idx": field_idx,
                "field_label": field_label,
                "label": _format_point_label(site, layer, field_label, n_fields, n_layers),
            })
    return points


def _parse_column_to_site_layer(col: Any, n_layers: int) -> Tuple[str, str]:
    """Extract (site, layer) from column name.

    Column is 'layer - site' when n_layers>1 else just site. When no layer
    can be split off, the layer is returned as an empty string and the
    whole column (converted to str) is the site.
    """
    if n_layers > 1 and isinstance(col, str) and " - " in col:
        # Split on the first separator only, so a site name may itself contain " - "
        layer, site = col.split(" - ", 1)
        return site, layer
    return str(col), ""


def _parse_column_label(col: Any, n_layers: int) -> str:
    """Get a readable label for a column (for warning messages)."""
    site, layer = _parse_column_to_site_layer(col, n_layers)
    return f"{site} - {layer}" if layer else site


def _format_point_label(
    site: str,
    layer: str,
    field_label: str,
    n_fields: int,
    n_layers: int,
) -> str:
    """Format combined label for legend. Include all varying dimensions.

    The field label and layer are included only when there is more than one
    of them; the site is included whenever it is non-empty. Falls back to
    "Point" if nothing qualifies.
    """
    parts = []
    if n_fields > 1:
        parts.append(field_label)
    if n_layers > 1:
        parts.append(layer)
    if site:
        parts.append(site)
    return " - ".join(parts) if parts else "Point"


def _assign_colors_and_markers(
    points_data: List[dict],
    color_dim: Optional[str],
    symbol_dim: Optional[str],
    palette: Optional[str],
    custom_palette: Optional[str],
    default_marker: str,
) -> Tuple[List[Any], List[str]]:
    """Assign color and marker to each point based on color_dim and symbol_dim.

    Classes are numbered in order of first appearance in `points_data`.
    Without a color dimension every point gets the first palette color;
    without a symbol dimension every point gets `default_marker`.

    Returns
    -------
    Tuple[List[Any], List[str]]
        Colors and markers, each parallel to `points_data`.
    """
    color_key = _POINT_KEY_BY_DIM.get(color_dim)
    if color_key is not None:
        # dict.fromkeys de-duplicates while preserving first-seen order
        color_classes = list(dict.fromkeys(p[color_key] for p in points_data))
        colors_list, _ = get_color_palette(
            len(color_classes), palette=palette, custom_palette=custom_palette
        )
        color_map = {c: colors_list[i] for i, c in enumerate(color_classes)}
        colors = [color_map[p[color_key]] for p in points_data]
    else:
        colors_list, _ = get_color_palette(1, palette=palette, custom_palette=custom_palette)
        colors = [colors_list[0] for _ in points_data]

    # Only fields and layers are ever encoded as marker symbols.
    symbol_key = _POINT_KEY_BY_DIM[symbol_dim] if symbol_dim in ("field", "layer") else None
    if symbol_key is not None:
        symbol_classes = list(dict.fromkeys(p[symbol_key] for p in points_data))
        markers = [_get_marker_for_class(p[symbol_key], symbol_classes) for p in points_data]
    else:
        markers = [default_marker for _ in points_data]

    return colors, markers


def _compute_stats(
    model_ts: pd.Series,
    obs_ts: pd.Series,
) -> Tuple[Optional[float], Optional[float]]:
    """
    Helper function to compute statistics from aligned time series.

    The series are aligned on their shared index labels and rows where
    either value is missing are dropped before computing.

    Returns
    -------
    Tuple[Optional[float], Optional[float]]
        (model std / obs std, Pearson correlation), or (None, None) when
        fewer than two paired values remain, when either series has zero or
        non-finite standard deviation, or when the correlation is not
        finite. A constant series has no defined correlation, so such a
        point cannot be placed on the diagram.
    """
    model_ts, obs_ts = model_ts.align(obs_ts, join='inner')
    valid = model_ts.notna() & obs_ts.notna()
    model_ts, obs_ts = model_ts[valid], obs_ts[valid]

    if len(model_ts) < 2:
        return None, None

    std_model, std_obs = float(model_ts.std()), float(obs_ts.std())
    if not (np.isfinite(std_model) and np.isfinite(std_obs)):
        return None, None
    # Checked before corrcoef: a zero std would make numpy divide by zero
    # and hand back NaN together with a RuntimeWarning.
    if std_model == 0 or std_obs == 0:
        return None, None

    corr = float(np.corrcoef(model_ts.values, obs_ts.values)[0, 1])
    if not np.isfinite(corr):
        return None, None
    return std_model / std_obs, corr


def _plot_taylor_diagram(
    stddev_norm: List[float],
    corrcoef: List[float],
    label_names: List[str],
    colors: List[Any],
    markers: List[str],
    x_fig_val: float,
    y_fig_val: float,
    marker_size: float,
    fontsize: int,
    title: Optional[str],
    dark_mode: bool,
    subfig=None,
    **kwargs,
) -> plt.Figure:
    """
    Helper function to create a Taylor diagram figure. Returns the figure.

    `stddev_norm`, `corrcoef`, `label_names`, `colors` and `markers` are
    parallel lists with one entry per point. Extra `kwargs` override the
    default arguments passed to `taylor.add_model_set` for every point.
    When `subfig` is given, the returned object is `subfig.figure`.
    """
    if subfig is not None:
        # TaylorDiagram can accept a SubFigure directly
        taylor = gv.TaylorDiagram(fig=subfig, label='obs')
        ax = plt.gca()
        fig = subfig.figure
    else:
        # Standard case: create figure and axes
        fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig_val, y_fig=y_fig_val)
        # TaylorDiagram creates its own axes, so we need to remove the one we created
        ax.remove()
        taylor = gv.TaylorDiagram(fig=fig, label='obs')
        # Get the axes created by TaylorDiagram
        ax = plt.gca()

    legend_handles = []

    # Plot points one at a time so each can carry its own color and marker
    for std, corr, label, color, marker in zip(
        stddev_norm, corrcoef, label_names, colors, markers
    ):
        plot_kwargs = {
            "color": color,
            "model_outlier_on": True,
            "annotate_on": False,
            "marker": marker,
            "facecolors": color,
            "s": marker_size,
            **kwargs,  # caller-supplied options win over the defaults above
        }
        taylor.add_model_set([std], [corr], **plot_kwargs)
        # Proxy artist: the legend is built by hand from these handles
        legend_handles.append(
            mlines.Line2D(
                [0], [0], marker=marker, color='w', markerfacecolor=color,
                markeredgecolor=color, markersize=8, label=label, linestyle='None',
            )
        )

    # Add reference point and contours
    taylor.add_model_set([1.0], [1.0], color='black', facecolors='black', s=marker_size, annotate_on=False)
    taylor.add_contours(levels=np.arange(0, 1.1, 0.25), colors='lightgrey', linewidths=0.5)
    taylor.add_corr_grid(np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]))

    # Add legend: when multiple symbols, use one column per symbol
    if legend_handles:
        unique_markers = list(dict.fromkeys(markers))
        if len(unique_markers) > 1:
            # Group handles by marker and reorder so each column = one symbol.
            # NOTE(review): this presumably relies on every marker having the
            # same number of entries (e.g. none skipped) - confirm how the
            # legend splits entries across columns otherwise.
            by_marker = {}
            for handle, m in zip(legend_handles, markers):
                by_marker.setdefault(m, []).append(handle)
            reordered_handles = [h for m in unique_markers for h in by_marker[m]]
            ncol = len(unique_markers)
        else:
            reordered_handles = legend_handles
            ncol = 1
        ax.legend(
            handles=reordered_handles, loc='center left', bbox_to_anchor=(0.95, 0.75),
            frameon=False, fontsize=fontsize, ncol=ncol,
        )

    # Add title
    if title:
        ax.set_title(title, fontsize=18, pad=100)

    # Apply dark mode if requested
    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)
        leg = ax.get_legend()
        if leg is not None:
            for text in leg.get_texts():
                text.set_color('white')

    return fig