from typing import Any, Optional, List, Union
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import (
get_color_palette, make_dark_mode, handle_figure_output, get_field_metadata,
)
[docs]
def make_static_plot(
field_a: cp.Field,
field_b: cp.Field,
kind: Optional[str] = 'scatter',
output_file: Optional[str] = None,
layers: Optional[List[str] | str] = None,
field_a_label: Optional[str] = None,
field_b_label: Optional[str] = None,
unit_a: Optional[str] = None,
unit_b: Optional[str] = None,
scatter_size: Optional[float] = 6,
scatter_alpha: Optional[float] = 0.5,
title: Optional[str] = None,
palette: Optional[str] = None,
custom_palette: Optional[str] = None,
move_legend: Optional[bool] = False,
dark_mode: Optional[bool] = False,
transparent: Optional[bool] = False,
x_label_rotation: float = 0,
x_fig: Optional[float] = 10,
y_fig: Optional[float] = 10,
subfig=None,
return_fig: Optional[bool] = False,
**kwargs,
) -> Optional[plt.Figure]:
"""
This function generates a scatter plot with regression lines and r-scores, a histogram or a kde plot,
comparing two input fields (which can be reduced spatially, temporally or both).
Parameters
----------
field_a, field_b : cp.Field
Input data Field to display.
kind : str, optional
Kind of plot to draw. Default is 'scatter', which uses `seaborn.regplot` (supports multiple layers).
Option 'hist' uses `seaborn.histplot` (supports multiple layers).
Option 'kde' uses `seaborn.kdeplot` (supports multiple layers).
output_file : str, optional
File path for saving the plot.
layers : List[str] or str, optional
Layers to plot from the input data.
field_a_label, field_b_label : str, optional
Labels for the data series, if not provided canopy will try to retrieve the name of the variable in the metadata.
unit_a, unit_b : str, optional
Units for the data series, if not provided canopy will try to retrieve the unit of the variable in the metadata.
scatter_size : float, optional
Marker size for scatter points. Default is 6.
scatter_alpha : float, optional
Transparency (alpha) for scatter points. Default is 0.5.
title : str, optional
Title of the plot.
palette : str, optional
Seaborn color palette to use for the line colors (https://seaborn.pydata.org/tutorial/color_palettes.html,
recommended palette are in https://colorbrewer2.org).
custom_palette : str, optional
Path of custom color palette .txt file to use. Names should match label names.
move_legend : bool, optional
Location of the legend ('in' or 'out'). Default is False.
dark_mode : bool, optional
Whether to apply dark mode styling to the plot.
transparent : bool, optional
If True, makes the background of the figure transparent.
x_label_rotation : float, optional
Rotation angle in degrees for the x-axis tick labels. Default is 0.
x_fig : float, optional
Width of the figure in inches. Default is 10.
y_fig : float, optional
Height of the figure in inches. Default is 10.
subfig : matplotlib.figure.SubFigure, optional
If provided, the plot will be created in this subfigure instead of creating a new figure.
This is used by multiple_figs() to combine multiple plots.
User can also provide a plt.figure.subfigure object
(https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html)
return_fig : bool, optional
If True, return a callable wrapper function instead of creating the plot immediately.
This wrapper can be used with multiple_figs(). Default is False.
**kwargs
Additional keyword arguments are passed directly to `seaborn.regplot` (if `kind='scatter'`)
or `seaborn.histplot` (if `kind='hist'`). This allows customization of plot features.
"""
# If return_fig is True, create a wrapper function and return it
if return_fig:
return create_wrapper_from_locals(make_static_plot, locals())
# Create boolean if grid type is sites
sites = field_a.grid.grid_type == "sites"
# Force variables to be a list
if isinstance(layers, str):
layers = [layers]
if not isinstance(sites, bool) and not isinstance(sites, list):
sites = [sites]
# Retrieve metadata
field_a_label, unit_a, layers = get_field_metadata(field_a, field_a_label, unit_a, layers)
field_b_label, unit_b, _ = get_field_metadata(field_b, field_b_label, unit_b, None)
_pre_checks(field_a, field_b, field_a_label, field_b_label, layers, sites)
# If sites (from parameter or grid type), retrieve sites labels (or lon, lat) as layers
if sites:
layers = list[str](field_a.sites.keys())
df_a = cp.make_lines(field_a)
df_b = cp.make_lines(field_b)
# Simplify MultiIndex columns for sites: use label if available, otherwise 'lon, lat' format
if sites:
df_a = _make_columns_sites(df_a, list[str](field_a.sites.keys()))
df_b = _make_columns_sites(df_b, list[str](field_b.sites.keys()))
# Prepare axis labels (used by both paths)
xlabel = f"{field_a_label} (in {unit_a})" if unit_a != "[no units]" else field_a_label
ylabel = f"{field_b_label} (in {unit_b})" if unit_b != "[no units]" else field_b_label
# Set up the figure and axes
fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig, y_fig=y_fig)
# Build items to plot: (label, x, y) per layer
items = []
for layer in layers:
x1d, y1d = _to_aligned_1d(df_a[layer], df_b[layer])
items.append((layer, x1d, y1d))
# Colour palette
colors, _ = get_color_palette(len(items), palette=palette, custom_palette=custom_palette)
# Filter kwargs to remove subfig
filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'subfig'}
min_val = min(min(x.min(), y.min()) for _, x, y in items)
max_val = max(max(x.max(), y.max()) for _, x, y in items)
for i, (label, x, y) in enumerate(tqdm(items, desc="Processing layers")):
_plot_layer(kind, ax, x, y, label, colors[i], scatter_size, scatter_alpha, filtered_kwargs)
# Annotate counts for histograms
if kind == "hist":
_annotate_hist_counts(ax)
# Create legend for hist/kde
if kind in ("hist", "kde"):
from matplotlib.patches import Patch
handles = [
Patch(facecolor=colors[i], edgecolor="none", label=f"{layer}")
for i, layer in enumerate(layers)
]
ax.legend(handles=handles, loc="best", frameon=False, fontsize=14)
# Legend options
if kind == "scatter":
leg = ax.legend(loc='best', frameon=False, fontsize=14)
else:
leg = ax.get_legend()
# Set alpha on handles before moving legend
if leg is not None:
for handle in leg.legend_handles:
if hasattr(handle, 'set_alpha'):
handle.set_alpha(1.0)
if move_legend:
sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.95))
# Get the legend again after moving (move_legend creates a new legend)
leg = ax.get_legend()
# Handle axis limits, labels and title
ax.set_xlim([min_val, max_val])
if kind == 'scatter':
ax.set_ylim([min_val, max_val])
ax.set_xlabel(xlabel, fontsize=16)
if x_label_rotation:
ax.tick_params(axis='x', labelrotation=x_label_rotation)
ax.set_ylabel(ylabel, fontsize=16)
if title:
ax.set_title(title, fontsize=18)
# Remove the top and right spines
sns.despine(ax=ax)
# Apply dark mode if requested
if dark_mode:
fig, ax = make_dark_mode(fig, ax, legend_style=None)
# Adjust figure layout to accommodate moved legend (only for standalone figures)
if move_legend and subfig is None:
fig.tight_layout()
return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)
def _pre_checks(
field_a: cp.Field,
field_b: cp.Field,
field_a_label: str,
field_b_label: str,
layers: List[str],
sites: bool,
) -> None:
"""Validate inputs before processing."""
# Sites grids can only be compared for a single variable/layer.
if sites and len(layers) > 1:
raise ValueError("layers and sites argument cannot be used simultaneously. Only one layer for multiple sites.")
# Check if both fields have the same grid
if hasattr(field_a.grid, 'min') and hasattr(field_b.grid, 'min'):
both_reduced = all(f.grid.is_reduced('lat') and f.grid.is_reduced('lon') for f in (field_a, field_b))
if not both_reduced and (field_a.grid.min != field_b.grid.min or field_a.grid.max != field_b.grid.max):
ga, gb = field_a.grid, field_b.grid
raise ValueError(
f"Fields have different grids: {field_a_label} {ga.min}-{ga.max} vs {field_b_label} {gb.min}-{gb.max}. "
)
def _make_columns_sites(
df: pd.DataFrame,
site_keys: List[str],
) -> pd.DataFrame:
"""
Simplify MultiIndex columns for sites data by using site keys.
"""
if not isinstance(df.columns, pd.MultiIndex):
return df
# If 'label' exists in column names, extract labels to match site keys
if 'label' in df.columns.names:
label_idx = df.columns.names.index('label')
new_columns = [col[label_idx] for col in df.columns]
else:
new_columns = site_keys if len(df.columns) == len(site_keys) else [str(col) for col in df.columns]
df = df.copy()
df.columns = new_columns
return df
def _to_aligned_1d(
x_obj: Union[pd.Series, pd.DataFrame],
y_obj: Union[pd.Series, pd.DataFrame],
) -> tuple[pd.Series, pd.Series]:
"""
Return pairwise-valid, aligned 1D Series for x and y (handles Series/DataFrame)
"""
# Stack the data to handle multiple layers
xs = x_obj.stack(future_stack=True) if isinstance(x_obj, pd.DataFrame) else x_obj
ys = y_obj.stack(future_stack=True) if isinstance(y_obj, pd.DataFrame) else y_obj
xs, ys = xs.align(ys, join='inner')
# Unstack the data to handle multiple layers
if isinstance(xs, pd.DataFrame):
xs = xs.stack(future_stack=True)
if isinstance(ys, pd.DataFrame):
ys = ys.stack(future_stack=True)
# Get valid data
valid = xs.notna() & ys.notna()
# Convert to numpy arrays
if isinstance(valid, pd.DataFrame):
valid = valid.to_numpy().ravel()
xs = pd.Series(xs.to_numpy().ravel())
ys = pd.Series(ys.to_numpy().ravel())
return xs[valid], ys[valid]
def _plot_layer(
kind: str,
ax: Axes,
x: pd.Series,
y: pd.Series,
label: str,
color: Any,
scatter_size: float,
scatter_alpha: float,
filtered_kwargs: dict,
) -> None:
"""Plot one series (layer or season) according to kind."""
plot_kwargs = {"x": x, "y": y, "ax": ax, "color": color, **filtered_kwargs}
if kind == 'scatter':
r_val = np.corrcoef(x.values, y.values)[0, 1]
plot_kwargs["label"] = f"{label} (R={r_val:.2f})"
plot_kwargs["scatter_kws"] = {"s": scatter_size, "alpha": scatter_alpha}
sns.regplot(**plot_kwargs)
elif kind == 'hist':
plot_kwargs["label"] = label
sns.histplot(**plot_kwargs)
elif kind == 'kde':
plot_kwargs["label"] = label
sns.kdeplot(**plot_kwargs)
else:
raise ValueError(f"Unsupported kind: {kind}, must be 'scatter', 'hist' or 'kde'.")
def _annotate_hist_counts(ax: Axes) -> None:
"""
Annotate each rectangle in a histogram plot with the count of points in that bin.
"""
# Process all collections (seaborn histplot uses collections for 2D histograms)
for collection in ax.collections:
# Get array/data from collection
try:
array = collection.get_array()
except:
continue
# Get paths from collection
try:
paths = collection.get_paths()
except:
continue
if array is None or len(paths) == 0:
continue
# Get the array values (counts)
counts = array.data if hasattr(array, 'data') else array
if counts is None:
continue
# Convert to numpy array and flatten for easier indexing
counts = np.asarray(counts).flatten()
# Find max count for text color threshold
valid_counts = counts[~np.isnan(counts) & (counts > 0)]
if len(valid_counts) == 0:
continue
max_count = np.max(valid_counts)
# Annotate each path (rectangle/polygon)
for idx, path in enumerate(paths):
if idx >= len(counts):
break
count = counts[idx]
if count <= 0 or np.isnan(count):
continue
# Get the center of the shape from the path
vertices = path.vertices
if len(vertices) == 0:
continue
x_center = np.mean(vertices[:, 0])
y_center = np.mean(vertices[:, 1])
# Use white text for bins with counts above 50% of max
text_color = 'white' if count > max_count * 0.5 else 'black'
# Add text annotation
ax.text(x_center, y_center, f'{int(count)}',
ha='center', va='center', fontsize=8,
color=text_color, weight='bold')