import warnings
from typing import Optional, List, Tuple, Any
import geocat.viz as gv
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import get_color_palette, make_dark_mode, handle_figure_output
# Marker symbols for differentiating layers/fields (cycle if more than available)
MARKERS = ['o', 's', '^', 'v', 'D', 'p', 'h', '*', '<', '>']
[docs]
def make_taylor_diagram(
    fields: cp.Field | List[cp.Field],
    obs: cp.Field,
    output_file: Optional[str] = None,
    gridop: Optional[str] = None,
    title: Optional[str] = None,
    layers: Optional[List[str] | str] = None,
    field_labels: Optional[List[str]] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
    dark_mode: bool = False,
    transparent: bool = False,
    marker_size: float = 100,
    marker: str = 'o',
    fontsize: int = 16,
    x_fig: float = 12,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a Taylor diagram (https://pcmdi.llnl.gov/staff/taylor/CV/Taylor_diagram_primer.pdf)
    comparing model data against observations.

    Note: This function requires all input fields and observations to have grid_type='sites'.

    Points can be distinguished along at most two of the three dimensions
    (sites, fields, layers): the first varying dimension is mapped to colour
    and the second, if any, to the marker symbol.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Model data Field(s) to compare against observations. If a list is provided,
        multiple fields will be plotted with one point per field. Must not be empty.
    obs : cp.Field
        Observation/reference data Field (typically the "truth").
    output_file : str, optional
        File path for saving the plot.
    gridop : str, optional
        If provided, the grid reduction operation applied to both `obs` and every
        field. Either None, 'sum' or 'av'. Default is None.
    title : str, optional
        Title of the plot.
    layers : List[str] or str, optional
        Layer names to include. Defaults to all layers in the first field.
    field_labels : List[str], optional
        Labels for each field. Required (one label per field) when more than one
        field is provided.
    palette : str, optional
        Seaborn color palette to use (https://seaborn.pydata.org/tutorial/color_palettes.html).
    custom_palette : str, optional
        Path of custom color palette .txt file to use. Names should match site labels.
    dark_mode : bool, optional
        Whether to apply dark mode styling to the plot.
    transparent : bool, optional
        If True, makes the background of the figure transparent.
    marker_size : float, optional
        Size of the markers. Default is 100.
    marker : str, optional
        Marker shape used when points are not differentiated by symbol.
        Default is 'o' (circle).
    fontsize : int, optional
        Font size of the legend entries. Default is 16.
    x_fig : float, optional
        Width of the figure in inches. Default is 12.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot will be created in this subfigure instead of creating a new figure.
        This is used by multiple_figs() to combine multiple plots.
        User can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html)
    return_fig : bool, optional
        If True, return a callable wrapper function instead of creating the plot immediately.
        This wrapper can be used with multiple_figs(). Default is False.
    **kwargs
        Additional keyword arguments are passed directly to `taylor.add_model_set` for the
        model points (they override the defaults set by this function). This allows
        customization of plot features such as `linestyle`, `linewidth`, etc.

    Returns
    -------
    The result of `handle_figure_output` for the created figure, or, when
    `return_fig` is True, the wrapper built by `create_wrapper_from_locals`.

    Raises
    ------
    ValueError
        If `fields` is empty, if the inputs fail validation (grid type, missing
        layers, missing or mismatched `field_labels`), if no point has enough valid
        data, or if sites, fields and layers all vary at the same time.
    """
    # Kept as the very first statement so that locals() holds only the call
    # arguments at the moment it is captured.
    if return_fig:
        return create_wrapper_from_locals(make_taylor_diagram, locals())

    # Force fields and layers to be lists
    if isinstance(fields, cp.Field):
        fields = [fields]
    if len(fields) == 0:
        # Fail with a clear message instead of an IndexError on fields[0] below.
        raise ValueError("At least one field must be provided")
    if isinstance(layers, str):
        layers = [layers]
    layers = layers or list(fields[0].layers)

    n_fields = len(fields)
    n_layers = len(layers)
    _pre_checks(fields, obs, field_labels, layers, n_fields)

    # Apply grid reduction to obs if needed
    obs_field = obs.reduce_grid(gridop) if gridop else obs

    # Build flattened data and compute stats with (color_class, symbol_class, label) per point
    points_data = _compute_taylor_points(
        fields, obs_field, gridop, layers, field_labels, n_fields, n_layers
    )
    if not points_data:
        raise ValueError("No valid data found with sufficient data for Taylor diagram")

    # Determine color_dim and symbol_dim from case matrix. The number of sites
    # is counted on the points that survived, not on the input grids.
    n_sites = len(set(p["site"] for p in points_data))
    if n_sites > 1 and n_fields > 1 and n_layers > 1:
        raise ValueError(
            "Cannot differentiate 3 categories (sites, fields, layers). Use at most 2: "
            "e.g. reduce to one site, one field, or one layer."
        )
    color_dim, symbol_dim = _taylorvar_case_matrix(n_sites, n_fields, n_layers)

    # Assign colors and markers per point
    colors, markers = _assign_colors_and_markers(
        points_data, color_dim, symbol_dim, palette, custom_palette, marker
    )
    stddev_norm = [p["std_norm"] for p in points_data]
    corrcoef = [p["corr"] for p in points_data]
    label_names = [p["label"] for p in points_data]

    # Create the Taylor diagram figure
    fig = _plot_taylor_diagram(
        stddev_norm, corrcoef, label_names, colors, markers, x_fig, y_fig,
        marker_size, fontsize,
        title, dark_mode, subfig=subfig, **kwargs
    )

    # Handle output and return
    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)
def _pre_checks(
    fields: List[cp.Field],
    obs: cp.Field,
    field_labels: Optional[List[str]],
    layers: List[str],
    n_fields: int,
) -> None:
    """
    Validate the inputs of the Taylor diagram before any processing.

    Checks, in this order: every field is on a 'sites' grid, `obs` is on a
    'sites' grid, every requested layer exists in `obs` and in every field,
    and (only when `n_fields` > 1) `field_labels` is given with exactly
    `n_fields` entries.

    Raises
    ------
    ValueError
        On the first check that fails.
    """
    for idx, candidate in enumerate(fields):
        found_type = candidate.grid.grid_type
        if found_type != "sites":
            raise ValueError(f"All fields must have grid_type='sites'. Found: field[{idx}] has grid_type='{found_type}'")

    obs_type = obs.grid.grid_type
    if obs_type != "sites":
        raise ValueError(f"Obs must have grid_type='sites'. Found: obs has grid_type='{obs_type}'")

    for layer_name in layers:
        if layer_name not in obs.layers:
            raise ValueError(f"Obs must contain layer '{layer_name}'. Obs has layers: {list(obs.layers)}")
        for idx, candidate in enumerate(fields):
            if layer_name not in candidate.layers:
                raise ValueError(f"Field[{idx}] must contain layer '{layer_name}'. Found: {list(candidate.layers)}")

    # Labels are only mandatory when several fields have to be told apart.
    if n_fields <= 1:
        return
    if field_labels is None:
        raise ValueError("field_labels must be provided when multiple fields are provided")
    n_labels = len(field_labels)
    if n_labels != n_fields:
        raise ValueError(f"field_labels length ({n_labels}) must match number of fields ({n_fields})")
def _taylorvar_case_matrix(
    n_sites: int,
    n_fields: int,
    n_layers: int,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Choose which dimension drives colour and which drives the marker symbol.

    A dimension "varies" when its count is greater than one. The first varying
    dimension (in the order site, field, layer) becomes `color_dim` and the
    second one, if any, becomes `symbol_dim`. When nothing varies, or when all
    three vary, both entries are None.

    Returns
    -------
    (color_dim, symbol_dim), each one of "site", "field", "layer" or None.
    """
    # Keyed by (sites vary, fields vary, layers vary).
    dims_by_case = {
        (True, True, True): (None, None),  # unsupported: nothing is assigned here
        (True, True, False): ("site", "field"),
        (True, False, True): ("site", "layer"),
        (True, False, False): ("site", None),
        (False, True, True): ("field", "layer"),
        (False, True, False): ("field", None),
        (False, False, True): ("layer", None),
        (False, False, False): (None, None),
    }
    varies = (n_sites > 1, n_fields > 1, n_layers > 1)
    return dims_by_case[varies]
def _get_marker_for_class(
    symbol_class: Any,
    symbol_classes: List[Any],
) -> str:
    """
    Return the marker symbol assigned to `symbol_class`.

    The marker is chosen by the position of `symbol_class` in `symbol_classes`,
    wrapping around when there are more classes than entries in MARKERS. A
    class that is not listed gets the first marker.
    """
    try:
        position = symbol_classes.index(symbol_class)
    except ValueError:
        # Unknown class: fall back to the first marker.
        position = 0
    return MARKERS[position % len(MARKERS)]
def _compute_taylor_points(
    fields: List[cp.Field],
    obs_field: cp.Field,
    gridop: Optional[str],
    layers: List[str],
    field_labels: Optional[List[str]],
    n_fields: int,
    n_layers: int,
) -> List[dict]:
    """
    Build one Taylor-diagram point per (field, observation column) pair.

    Every field is optionally grid-reduced with `gridop`, converted with
    `cp.make_lines` and compared column by column against the observation
    table. Columns missing from the field are skipped silently; columns for
    which `_compute_stats` yields no statistics are skipped with a warning.

    Returns
    -------
    List of dicts with keys "std_norm", "corr", "site", "layer", "field_idx",
    "field_label" and "label", in field order then observation-column order.
    """
    obs_table = cp.make_lines(obs_field, flatten_columns=True, layers=layers)
    obs_column_names = list(obs_table.columns)

    collected = []
    for position, model_field in enumerate(fields):
        reduced = model_field.reduce_grid(gridop) if gridop else model_field
        model_table = cp.make_lines(reduced, flatten_columns=True, layers=layers)

        # The field label only depends on the field, not on the column.
        if field_labels and position < len(field_labels):
            name = field_labels[position]
        else:
            name = f"Field {position}"

        for column in obs_column_names:
            if column not in model_table.columns:
                continue

            ratio, correlation = _compute_stats(model_table[column], obs_table[column])
            if ratio is None:
                label = _parse_column_label(column, n_layers)
                warnings.warn(f"Warning: Skipping {label}: insufficient data or zero obs std dev")
                continue

            site, layer = _parse_column_to_site_layer(column, n_layers)
            collected.append({
                "std_norm": ratio,
                "corr": correlation,
                "site": site,
                "layer": layer,
                "field_idx": position,
                "field_label": name,
                "label": _format_point_label(site, layer, name, n_fields, n_layers),
            })
    return collected
def _parse_column_to_site_layer(col: Any, n_layers: int) -> Tuple[str, str]:
    """
    Split a flattened column name into (site, layer).

    With more than one layer, a string column of the form 'layer - site' is
    split on the first ' - '. In every other case the whole column (converted
    to str) is the site and the layer is the empty string.
    """
    if n_layers > 1 and isinstance(col, str):
        layer_part, separator, site_part = col.partition(" - ")
        if separator:
            return site_part, layer_part
    return str(col), ""
def _parse_column_label(col: Any, n_layers: int) -> str:
    """
    Build a human-readable name for a column, used in warning messages.

    Returns 'site - layer' when the column carries a layer, otherwise the site alone.
    """
    site_name, layer_name = _parse_column_to_site_layer(col, n_layers)
    return f"{site_name} - {layer_name}" if layer_name else site_name
def _format_point_label(
    site: str,
    layer: str,
    field_label: str,
    n_fields: int,
    n_layers: int,
) -> str:
    """
    Compose the legend label of a point.

    The label joins, with ' - ', the field label (only when there are several
    fields), the layer (only when there are several layers) and the site
    (only when it is non-empty). If none of them is included the label is
    'Point'.
    """
    candidates = (
        (n_fields > 1, field_label),
        (n_layers > 1, layer),
        (bool(site), site),
    )
    pieces = [text for wanted, text in candidates if wanted]
    if not pieces:
        return "Point"
    return " - ".join(pieces)
def _assign_colors_and_markers(
    points_data: List[dict],
    color_dim: Optional[str],
    symbol_dim: Optional[str],
    palette: Optional[str],
    custom_palette: Optional[str],
    default_marker: str,
) -> Tuple[List[Any], List[str]]:
    """
    Pick a colour and a marker for every point.

    Colours: when `color_dim` is set, the distinct values of that dimension
    (in order of first appearance) each get one colour from
    `get_color_palette`; otherwise every point gets the single colour of a
    one-entry palette.

    Markers: when `symbol_dim` is "field" or "layer", the marker comes from
    `_get_marker_for_class` based on the distinct values of that dimension;
    otherwise every point uses `default_marker`.

    Returns
    -------
    (colors, markers), two lists aligned with `points_data`.
    """
    # Name of the point-dict entry that holds the value of each dimension.
    color_entry = {"site": "site", "field": "field_label", "layer": "layer"}
    marker_entry = {"field": "field_label", "layer": "layer"}

    if color_dim:
        grouping_entry = color_entry.get(color_dim, "layer")
        color_classes = list(dict.fromkeys(point[grouping_entry] for point in points_data))
        palette_colors, _ = get_color_palette(
            len(color_classes), palette=palette, custom_palette=custom_palette
        )
        class_to_color = {name: palette_colors[pos] for pos, name in enumerate(color_classes)}
    else:
        palette_colors, _ = get_color_palette(1, palette=palette, custom_palette=custom_palette)
        class_to_color = {}

    if symbol_dim:
        grouping_entry = "field_label" if symbol_dim == "field" else "layer"
        symbol_classes = list(dict.fromkeys(point[grouping_entry] for point in points_data))
    else:
        symbol_classes = []

    color_lookup = color_entry.get(color_dim)
    marker_lookup = marker_entry.get(symbol_dim)

    colors = []
    markers = []
    for point in points_data:
        color_key = point[color_lookup] if color_lookup is not None else None
        # Unmapped keys (or no colour dimension at all) use the first palette colour.
        colors.append(class_to_color.get(color_key, palette_colors[0]))

        if marker_lookup is not None:
            markers.append(_get_marker_for_class(point[marker_lookup], symbol_classes))
        else:
            markers.append(default_marker)
    return colors, markers
def _compute_stats(model_ts: pd.Series, obs_ts: pd.Series) -> Tuple[Optional[float], Optional[float]]:
    """
    Compute the Taylor-diagram statistics of a model series against observations.

    The two series are aligned on their common index and every position where
    either value is missing is dropped before the statistics are computed.

    Parameters
    ----------
    model_ts : pd.Series
        Model time series.
    obs_ts : pd.Series
        Observation time series.

    Returns
    -------
    (std_norm, corr)
        `std_norm` is the model standard deviation divided by the observation
        standard deviation and `corr` is the correlation coefficient from
        `np.corrcoef`. Returns (None, None) when the statistics are undefined:
        fewer than two valid pairs, a zero standard deviation in either series,
        or a non-finite correlation.
    """
    model_ts, obs_ts = model_ts.align(obs_ts, join='inner')
    valid = model_ts.notna() & obs_ts.notna()
    model_ts, obs_ts = model_ts[valid], obs_ts[valid]
    if len(model_ts) < 2:
        return None, None

    std_model, std_obs = float(model_ts.std()), float(obs_ts.std())
    # A constant series has no defined correlation: np.corrcoef would return
    # NaN (with a RuntimeWarning), and a zero obs std dev cannot normalise.
    if std_obs == 0 or std_model == 0:
        return None, None

    corr = float(np.corrcoef(model_ts.values, obs_ts.values)[0, 1])
    if not np.isfinite(corr):
        return None, None
    return std_model / std_obs, corr
def _plot_taylor_diagram(
    stddev_norm: List[float],
    corrcoef: List[float],
    label_names: List[str],
    colors: List[Any],
    markers: List[str],
    x_fig_val: float,
    y_fig_val: float,
    marker_size: float,
    fontsize: int,
    title: Optional[str],
    dark_mode: bool,
    subfig=None,
    **kwargs,
) -> plt.Figure:
    """
    Draw the Taylor diagram and return the figure that contains it.

    One model point is added per (std, corr, label) triple using the matching
    entry of `colors` and `markers`, followed by a black reference point at
    (1.0, 1.0), contours and a correlation grid. A legend with one entry per
    point is attached; when several marker symbols are in use the entries are
    grouped by symbol and laid out with one column per symbol.

    `kwargs` are merged over the default keyword arguments given to
    `taylor.add_model_set` for the model points.
    """
    if subfig is None:
        # Standard case: create the figure, then drop the axes that came with
        # it because TaylorDiagram creates its own.
        fig, placeholder_ax = setup_figure_and_axes(subfig=None, x_fig=x_fig_val, y_fig=y_fig_val)
        placeholder_ax.remove()
        diagram_parent = fig
    else:
        # TaylorDiagram accepts the SubFigure directly.
        fig = subfig.figure
        diagram_parent = subfig

    taylor = gv.TaylorDiagram(fig=diagram_parent, label='obs')
    # Axes created by TaylorDiagram.
    ax = plt.gca()

    # Plot points with per-point colors and markers
    legend_handles = []
    for idx, (std, corr, label) in enumerate(zip(stddev_norm, corrcoef, label_names)):
        point_color = colors[idx]
        point_marker = markers[idx]
        model_kwargs = {
            "color": point_color,
            "model_outlier_on": True,
            "annotate_on": False,
            "marker": point_marker,
            "facecolors": point_color,
            "s": marker_size,
            **kwargs,
        }
        taylor.add_model_set([std], [corr], **model_kwargs)
        legend_handles.append(
            mlines.Line2D(
                [0], [0],
                marker=point_marker,
                color='w',
                markerfacecolor=point_color,
                markeredgecolor=point_color,
                markersize=8,
                label=label,
                linestyle='None',
            )
        )

    # Add reference point and contours
    taylor.add_model_set([1.0], [1.0], color='black', facecolors='black', s=marker_size, annotate_on=False)
    taylor.add_contours(levels=np.arange(0, 1.1, 0.25), colors='lightgrey', linewidths=0.5)
    taylor.add_corr_grid(np.array([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95,0.99]))

    # Add legend: when multiple symbols, use one column per symbol
    if legend_handles:
        distinct_markers = list(dict.fromkeys(markers))
        if len(distinct_markers) > 1:
            # Bucket the handles per marker (buckets keep first-appearance
            # order) and flatten them so each legend column holds one symbol.
            handles_per_marker = {}
            for idx, symbol in enumerate(markers):
                handles_per_marker.setdefault(symbol, []).append(legend_handles[idx])
            ordered_handles = [
                handle
                for symbol in distinct_markers
                for handle in handles_per_marker[symbol]
            ]
            n_columns = len(distinct_markers)
        else:
            ordered_handles = legend_handles
            n_columns = 1
        ax.legend(handles=ordered_handles, loc='center left', bbox_to_anchor=(0.95, 0.75),
                  frameon=False, fontsize=fontsize, ncol=n_columns)

    # Add title
    if title:
        ax.set_title(title, fontsize=18, pad=100)

    # Apply dark mode if requested
    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)
        legend = ax.get_legend()
        if legend is not None:
            for entry in legend.get_texts():
                entry.set_color('white')
    return fig