import warnings
from typing import Optional, List
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import pandas as pd
import seaborn as sns
import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import (
get_color_palette, make_dark_mode, handle_figure_output,
format_value_label, set_axis_style, get_field_metadata,
)
# Plot-type-specific seaborn options
_PLOT_OPTIONS = {
"box": {"fill": False, "showfliers": False, "gap": 0.1},
"boxen": {"fill": False, "showfliers": False, "gap": 0.1},
"violin": {"inner": None, "bw_method": 1},
}
[docs]
def make_distribution_plot(
    fields: cp.Field | List[cp.Field],
    output_file: Optional[str] = None,
    plot_type: str = "box",
    layers: Optional[List[str] | str] = None,
    gridop: Optional[str] = 'av',
    yaxis_label: Optional[str] = None,
    field_labels: Optional[List[str]] = None,
    unit: Optional[List[str]] = None,
    title: Optional[str] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[List[str]] = None,
    horizontal: bool = False,
    vertical_xlabels: bool = False,
    x_label_rotation: float = 0,
    move_legend: bool = False,
    dark_mode: bool = False,
    transparent: bool = False,
    x_fig: float = 10,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a comparative distribution plot from one or several input Fields
    (for example, the same variable from different runs).

    Depending on `plot_type`, the function draws a box plot, violin plot,
    boxen plot, strip plot, swarm plot, point plot or bar plot. Count plots
    are not supported.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Input data Field(s) to display. A single Field is wrapped in a list.
    output_file : str, optional
        File path for saving the plot.
    plot_type : str, optional
        Type of plot. One of "box", "violin", "boxen", "strip", "swarm",
        "point" or "bar". Default is "box".
    layers : List[str] or str, optional
        Layer name(s) to display. A single string is wrapped in a list.
    gridop : str, optional
        If provided, the grid reduction operation. Either None, 'sum' or 'av'. Default is 'av'.
    yaxis_label : str, optional
        Label of the value axis, if not provided canopy will try to retrieve the name of the
        variable in the metadata.
    field_labels : List[str], optional
        Names of each series to display in the legend. Required when more than one field is given.
    unit : List[str], optional
        Unit of the plotted variable, if not provided canopy will try to retrieve the unit of the
        variable in the metadata.
    title : str, optional
        Title of the plot.
    palette : str, optional
        Seaborn color palette to use (https://seaborn.pydata.org/tutorial/color_palettes.html,
        recommended palettes are in https://colorbrewer2.org).
    custom_palette : List[str], optional
        Path of custom color palette .txt file to use. Names should match label names.
    horizontal : bool, optional
        If True, renders the plot with horizontal orientation (values on the x-axis,
        categories on the y-axis).
    vertical_xlabels : bool, optional
        If True, rotates the category tick labels by 90 degrees.
    x_label_rotation : float, optional
        Rotation angle in degrees for the category tick labels (the y tick labels when
        `horizontal` is True). Overrides vertical_xlabels when non-zero. Default is 0.
    move_legend : bool, optional
        Move the legend outside of plot. Ignored when the axes has no legend. Default is False.
    dark_mode : bool, optional
        If True, applies dark mode styling to the figure. Default is False.
    transparent : bool, optional
        If True, sets the figure background to be transparent when saved. Default is False.
    x_fig : float, optional
        Width of the figure in inches. Default is 10.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot will be created in this subfigure instead of creating a new figure.
        This is used by multiple_figs() to combine multiple plots.
        User can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html)
    return_fig : bool, optional
        If True, return a callable wrapper function instead of creating the plot immediately.
        This wrapper can be used with multiple_figs(). Default is False.
    **kwargs
        Additional keyword arguments are passed to the axes-level seaborn function selected by
        `plot_type` (e.g. `seaborn.boxplot`, `seaborn.violinplot`). They take precedence over
        the built-in per-plot-type defaults, so options such as `showfliers`, `fill`, `inner`
        or `errorbar` can be customized.

    Returns
    -------
    matplotlib.figure.Figure or None
        Whatever `handle_figure_output` returns for the finished figure. When `return_fig`
        is True, the wrapper built by `create_wrapper_from_locals` is returned instead.

    Raises
    ------
    ValueError
        If `plot_type` is "count" or not one of the supported types, or if several fields
        are given without `field_labels`.
    """
    # Must stay the first statement: locals() is handed to the wrapper factory,
    # so no extra local may be bound before this point.
    if return_fig:
        return create_wrapper_from_locals(make_distribution_plot, locals())

    # Reject bad plot types before any figure is created, so a typo does not
    # leave an empty figure behind. Keep this tuple in sync with the plot types
    # handled by _plot_distribution.
    if plot_type == "count":
        raise ValueError("count plot is not supported for distribution plot.")
    if plot_type not in ("box", "violin", "boxen", "strip", "swarm", "point", "bar"):
        raise ValueError(f"Unsupported plot_type: {plot_type}")

    # Force fields and layers to be lists
    if isinstance(fields, cp.Field):
        fields = [fields]
    if isinstance(layers, str):
        layers = [layers]

    # Retrieve metadata (fills in label, unit and layers where not supplied)
    yaxis_label, unit, layers = get_field_metadata(fields, yaxis_label, unit, layers)

    # Pre-checks
    if len(fields) > 1 and field_labels is None:
        raise ValueError("field_labels must be defined when there are more than one field.")
    # A single unlabeled field gets a blank series name.
    field_labels_list = field_labels if field_labels is not None else [" "]

    # Convert fields to long format (columns: value, series, category)
    df_long = _fields_to_long_format(fields, layers, field_labels_list, gridop)

    # Get color palette: one color per series label, matched by position.
    # NOTE(review): the second return value of get_color_palette is not used
    # here — confirm that positional matching is what custom palettes expect.
    colors, _ = get_color_palette(
        n_classes=len(field_labels_list), palette=palette, custom_palette=custom_palette
    )
    palette_dict = dict(zip(field_labels_list, colors))

    x, y = ("value", "category") if horizontal else ("category", "value")

    # Create figure and axis
    fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig, y_fig=y_fig)

    plot_kwargs = {
        "data": df_long,
        "x": x,
        "y": y,
        "hue": "series",
        "palette": palette_dict if palette_dict else palette,
        "ax": ax,
    }
    # Built-in defaults go in first so that explicit user kwargs override them.
    plot_kwargs.update(_PLOT_OPTIONS.get(plot_type, {}))
    if plot_type == "violin" and len(fields) == 2:  # Split violin plot for two fields
        plot_kwargs["split"] = True
    plot_kwargs.update({k: v for k, v in kwargs.items() if k != "subfig"})

    # Make the distribution plot
    ax = _plot_distribution(plot_type, plot_kwargs)

    # Set the axis labels and style: only the value axis carries a label
    axis_label = format_value_label(yaxis_label, unit)
    set_axis_style(ax, title=title, tick_labelsize=14,
                   x_label=axis_label if horizontal else "",
                   y_label="" if horizontal else axis_label)

    # Rotate category labels if requested to prevent overlap
    rotation = x_label_rotation if x_label_rotation else (90 if vertical_xlabels else 0)
    if rotation:
        ticklabels = ax.get_yticklabels() if horizontal else ax.get_xticklabels()
        plt.setp(ticklabels, rotation=rotation, ha="center" if horizontal else "right")

    # Custom legend (colored labels, no box, invisible handles)
    _, labels = ax.get_legend_handles_labels()
    if labels and palette_dict:
        label_colors = [palette_dict.get(label, "black") for label in labels]
        leg_handles = [plt.Line2D([], [], color=color, marker="", linestyle="")
                       for color in label_colors]
        ax.legend(handles=leg_handles, labels=labels, handlelength=0, handletextpad=0,
                  labelcolor=label_colors,
                  loc="best", frameon=False, fontsize=14)

    # sns.move_legend raises when the axes has no legend, so only move an existing one.
    if move_legend and ax.get_legend() is not None:
        sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.85), fontsize=16)

    # Apply dark mode
    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)

    # Reapply label colors after move_legend or dark_mode (both can reset them)
    leg = ax.get_legend()
    if labels and palette_dict and leg and (move_legend or dark_mode):
        for text, lbl in zip(leg.get_texts(), labels):
            text.set_color(palette_dict.get(lbl, "black"))

    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)
def _fields_to_long_format(
    fields: List[cp.Field],
    layers: List[str],
    field_labels: List[str],
    gridop: Optional[str],
) -> pd.DataFrame:
    """
    Stack the data of several fields into one long-format DataFrame.

    The result has three columns: ``value`` (the flattened data), ``series``
    (the label of the field the data came from) and ``category`` (the layer,
    or the individual column name when a single requested layer expands to
    several columns).

    Parameters
    ----------
    fields : List[cp.Field]
        Fields to convert. A field is reduced with `gridop` first unless its
        grid is already reduced along "lat" or "lon".
    layers : List[str]
        Layer names to extract from each field.
    field_labels : List[str]
        Series label per field, matched by position. Fields beyond the end of
        this list are labelled "Field <n>" (1-based).
    gridop : str, optional
        Grid reduction operation passed to `Field.reduce_grid`; falsy values
        disable the reduction.

    Returns
    -------
    pd.DataFrame
        Concatenation of all per-field, per-category frames with a fresh index.
    """
    frames = []
    for position, source in enumerate(fields):
        # Positional label, with a generated fallback when labels run out.
        if position < len(field_labels):
            series_name = field_labels[position]
        else:
            series_name = f"Field {position+1}"

        def build_frame(values, category):
            # One long-format chunk: flattened values tagged with series/category.
            return pd.DataFrame({
                "value": np.asarray(values).flatten(),
                "series": series_name,
                "category": category,
            })

        # Reduce the grid only if neither horizontal dimension is reduced yet.
        needs_reduction = gridop and not (
            source.grid.is_reduced("lat") or source.grid.is_reduced("lon")
        )
        reduced = source.reduce_grid(gridop) if needs_reduction else source

        table = cp.make_lines(reduced, flatten_columns=True, layers=layers)

        for layer in layers:
            single_layer = len(layers) == 1
            if layer in table.columns:
                # Exact column match.
                selected = [layer]
            elif single_layer:
                # Only one layer requested: take every column of the table.
                selected = list(table.columns)
            else:
                # Several layers requested: take the columns prefixed "<layer> - ".
                selected = [c for c in table.columns if str(c).startswith(f"{layer} - ")]

            if single_layer and len(selected) > 1:
                # Each column becomes its own category.
                frames.extend(build_frame(table[name].values, name) for name in selected)
            else:
                # All selected columns are pooled under the layer name.
                frames.append(build_frame(table[selected].values, layer))

    return pd.concat(frames, ignore_index=True)
def _plot_distribution(plot_type: str, plot_kwargs: dict[str]) -> Axes:
"""Dispatch to the appropriate seaborn plot function based on plot_type."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PendingDeprecationWarning, message=".*vert.*")
match plot_type:
case "box":
return sns.boxplot(**plot_kwargs)
case "violin":
return sns.violinplot(**plot_kwargs)
case "boxen":
return sns.boxenplot(**plot_kwargs)
case "strip":
return sns.stripplot(**plot_kwargs)
case "swarm":
return sns.swarmplot(**plot_kwargs)
case "point":
return sns.pointplot(**plot_kwargs)
case "bar":
return sns.barplot(**plot_kwargs)
case _:
raise ValueError(f"Unsupported plot_type: {plot_type}")