# Source code for canopy.visualization.distribution_plot.distribution_plot

import warnings
from typing import Optional, List

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import pandas as pd
import seaborn as sns

import canopy as cp
from canopy.visualization.multiple_figs import setup_figure_and_axes, create_wrapper_from_locals
from canopy.visualization.visualization_helpers import (
    get_color_palette, make_dark_mode, handle_figure_output,
    format_value_label, set_axis_style, get_field_metadata,
)

# Plot-type-specific seaborn options
_PLOT_OPTIONS = {
    "box": {"fill": False, "showfliers": False, "gap": 0.1},
    "boxen": {"fill": False, "showfliers": False, "gap": 0.1},
    "violin": {"inner": None, "bw_method": 1},
}

def make_distribution_plot(
    fields: cp.Field | List[cp.Field],
    output_file: Optional[str] = None,
    plot_type: str = "box",
    layers: Optional[List[str] | str] = None,
    gridop: Optional[str] = 'av',
    yaxis_label: Optional[str] = None,
    field_labels: Optional[List[str]] = None,
    unit: Optional[List[str]] = None,
    title: Optional[str] = None,
    palette: Optional[str] = None,
    custom_palette: Optional[List[str]] = None,
    horizontal: bool = False,
    vertical_xlabels: bool = False,
    x_label_rotation: float = 0,
    move_legend: bool = False,
    dark_mode: bool = False,
    transparent: bool = False,
    x_fig: float = 10,
    y_fig: float = 10,
    subfig=None,
    return_fig: bool = False,
    **kwargs,
) -> Optional[plt.Figure]:
    """
    Create a comparative distribution plot from one or more data Fields
    (for example, the same variable taken from different runs).

    Depending on `plot_type`, the function draws a box, boxen, violin, strip,
    swarm, point or bar plot. Count plots are not supported.

    Parameters
    ----------
    fields : cp.Field or List[cp.Field]
        Input data Field(s) to display.
    output_file : str, optional
        File path for saving the plot.
    plot_type : str, optional
        Type of plot: "strip", "swarm", "box", "violin", "boxen", "point" or
        "bar". Default is "box". With "violin" and exactly two fields, the
        violins are split by default.
    layers : List[str] or str, optional
        Layer name(s) to display.
    gridop : str, optional
        Grid reduction operation: None, 'sum' or 'av'. Default is 'av'. Only
        applied to fields whose "lat" and "lon" axes are both unreduced.
    yaxis_label : str, optional
        Value-axis label. If not provided, canopy tries to retrieve the name
        of the variable from the metadata.
    field_labels : List[str], optional
        Name of each series, shown in the legend and matched to `fields` by
        position. Required when more than one field is given. Fields without
        a corresponding label are named "Field <n>" (1-based).
    unit : List[str], optional
        Unit of the plotted variable. If not provided, canopy tries to
        retrieve the unit of the variable from the metadata.
    title : str, optional
        Title of the plot.
    palette : str, optional
        Seaborn color palette used for the series colors
        (https://seaborn.pydata.org/tutorial/color_palettes.html; recommended
        palettes are listed at https://colorbrewer2.org).
    custom_palette : List[str], optional
        Path of a custom color palette .txt file to use. Names should match
        label names.
    horizontal : bool, optional
        If True, renders the plot with horizontal orientation (flips the axes).
    vertical_xlabels : bool, optional
        If True, rotates the category tick labels by 90 degrees.
    x_label_rotation : float, optional
        Rotation angle in degrees for the category tick labels (x-axis, or
        y-axis when `horizontal` is True). Overrides `vertical_xlabels` when
        non-zero. Default is 0.
    move_legend : bool, optional
        Move the legend outside of the plot. Ignored when the plot has no
        legend. Default is False.
    dark_mode : bool, optional
        If True, applies dark mode styling to the figure. Default is False.
    transparent : bool, optional
        If True, sets the figure background to be transparent when saved.
        Default is False.
    x_fig : float, optional
        Width of the figure in inches. Default is 10.
    y_fig : float, optional
        Height of the figure in inches. Default is 10.
    subfig : matplotlib.figure.SubFigure, optional
        If provided, the plot is created in this subfigure instead of a new
        figure. This is used by multiple_figs() to combine multiple plots.
        Users can also provide a plt.figure.subfigure object
        (https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subfigures.html).
    return_fig : bool, optional
        If True, return a callable wrapper instead of creating the plot
        immediately. The wrapper can be used with multiple_figs().
        Default is False.
    **kwargs
        Additional keyword arguments passed to the underlying seaborn
        axes-level function (e.g. `seaborn.boxplot`, `seaborn.violinplot`).
        They take precedence over this function's plot-type defaults, so
        options such as `showfliers`, `inner`, `split`, `order` or `errorbar`
        can be customized.

    Returns
    -------
    plt.Figure or None or callable
        If `return_fig` is True, the wrapper built by
        `create_wrapper_from_locals`; otherwise the value returned by
        `handle_figure_output`.

    Raises
    ------
    ValueError
        If `plot_type` is "count" or otherwise unsupported, or if several
        fields are given without `field_labels`.
    """
    # Must run before any other local is created: locals() is forwarded as-is.
    if return_fig:
        return create_wrapper_from_locals(make_distribution_plot, locals())

    # Force fields to be a list
    if isinstance(fields, cp.Field):
        fields = [fields]

    if plot_type == "count":
        raise ValueError("count plot is not supported for distribution plot.")
    # Reject unknown plot types before any data conversion or figure creation,
    # so a failing call does not leave an empty figure behind.
    if plot_type not in ("strip", "swarm", "box", "violin", "boxen", "point", "bar"):
        raise ValueError(f"Unsupported plot_type: {plot_type}")

    # Force layers to be a list
    if isinstance(layers, str):
        layers = [layers]

    # Retrieve metadata
    yaxis_label, unit, layers = get_field_metadata(fields, yaxis_label, unit, layers)

    # Pre-checks
    if len(fields) > 1 and field_labels is None:
        raise ValueError("field_labels must be defined when there are more than one field.")
    # Copy so the caller's list is never mutated.
    field_labels_list = list(field_labels) if field_labels is not None else [" "]
    # Pad missing labels with the same "Field N" fallback used by
    # _fields_to_long_format, so every plotted series gets a palette entry
    # (seaborn rejects a palette dict that lacks any hue level).
    field_labels_list += [f"Field {i + 1}" for i in range(len(field_labels_list), len(fields))]

    # Convert fields to long format
    df_long = _fields_to_long_format(fields, layers, field_labels_list, gridop)

    # Colors are assigned to labels by position.
    # NOTE(review): the name-keyed mapping also returned by get_color_palette is
    # not used here; confirm positional assignment is intended for custom_palette.
    colors, _ = get_color_palette(
        n_classes=len(field_labels_list), palette=palette, custom_palette=custom_palette
    )
    palette_dict = dict(zip(field_labels_list, colors))

    x, y = ("value", "category") if horizontal else ("category", "value")

    # Create figure and axis
    fig, ax = setup_figure_and_axes(subfig=subfig, x_fig=x_fig, y_fig=y_fig)

    # Precedence, lowest to highest: data mapping, plot-type defaults,
    # automatic violin split, then user kwargs, so callers can override any
    # default (e.g. showfliers=True, inner="box", split=False).
    plot_kwargs = {
        "data": df_long,
        "x": x,
        "y": y,
        "hue": "series",
        "palette": palette_dict if palette_dict else palette,
        "ax": ax,
    }
    plot_kwargs.update(_PLOT_OPTIONS.get(plot_type, {}))
    if plot_type == "violin" and len(fields) == 2:
        # Split violin plot for two fields
        plot_kwargs["split"] = True
    plot_kwargs.update({k: v for k, v in kwargs.items() if k != "subfig"})

    # Make the distribution plot
    ax = _plot_distribution(plot_type, plot_kwargs)

    # Set the axis labels and style; the value label goes on whichever axis
    # carries the values.
    axis_label = format_value_label(yaxis_label, unit)
    set_axis_style(
        ax,
        title=title,
        tick_labelsize=14,
        x_label=axis_label if horizontal else "",
        y_label="" if horizontal else axis_label,
    )

    # Rotate category tick labels if requested to prevent overlap
    rotation = x_label_rotation if x_label_rotation else (90 if vertical_xlabels else 0)
    if rotation:
        ticklabels = ax.get_yticklabels() if horizontal else ax.get_xticklabels()
        plt.setp(ticklabels, rotation=rotation, ha="center" if horizontal else "right")

    # Custom legend: labels drawn in their series color, no markers, no box
    _, labels = ax.get_legend_handles_labels()
    if labels and palette_dict:
        leg_handles = [
            plt.Line2D([], [], color=palette_dict.get(lbl, "black"), marker="", linestyle="")
            for lbl in labels
        ]
        ax.legend(
            handles=leg_handles,
            labels=labels,
            handlelength=0,
            handletextpad=0,
            labelcolor=[palette_dict.get(lbl, "black") for lbl in labels],
            loc="best",
            frameon=False,
            fontsize=14,
        )
    # sns.move_legend raises ValueError when the axes has no legend
    # (e.g. legend=False was passed through kwargs).
    if move_legend and ax.get_legend() is not None:
        sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.85), fontsize=16)

    # Apply dark mode
    if dark_mode:
        fig, ax = make_dark_mode(fig, ax)

    # Reapply label colors after move_legend or dark_mode (both can reset them)
    leg = ax.get_legend()
    if labels and palette_dict and leg is not None and (move_legend or dark_mode):
        for text, lbl in zip(leg.get_texts(), labels):
            text.set_color(palette_dict.get(lbl, "black"))

    return handle_figure_output(fig, output_file=output_file, transparent=transparent, subfig=subfig)
def _fields_to_long_format(
    fields: List[cp.Field],
    layers: List[str],
    field_labels: List[str],
    gridop: Optional[str],
) -> pd.DataFrame:
    """Convert fields to a long-format DataFrame with columns (value, series, category).

    Parameters
    ----------
    fields : List[cp.Field]
        Fields to convert.
    layers : List[str]
        Layer names to extract from each field.
    field_labels : List[str]
        Series label for each field, matched by position. Fields beyond the
        end of this list are labelled "Field <n>" (1-based).
    gridop : str, optional
        Grid reduction passed to ``Field.reduce_grid``. Applied only when it
        is truthy and neither the "lat" nor the "lon" axis is already reduced.

    Returns
    -------
    pd.DataFrame
        One row per data value. ``value`` holds the flattened data, ``series``
        holds the field label, and ``category`` holds the layer name. When a
        single layer spans several columns, ``category`` holds the column name
        instead.

    Raises
    ------
    ValueError
        If there is nothing to convert (no fields or no layers).
    """
    single_layer = len(layers) == 1
    frames = []
    for i, field in enumerate(fields):
        label = field_labels[i] if i < len(field_labels) else f"Field {i + 1}"
        field_use = field
        if gridop and not field.grid.is_reduced("lat") and not field.grid.is_reduced("lon"):
            field_use = field.reduce_grid(gridop)
        df = cp.make_lines(field_use, flatten_columns=True, layers=layers)
        for layer in layers:
            cols = _layer_columns(df.columns, layer, single_layer)
            if single_layer and len(cols) > 1:
                # One layer spread over several columns: keep each column as
                # its own category so they are drawn side by side.
                for col in cols:
                    frames.append(_long_frame(df[col], label, col))
            else:
                frames.append(_long_frame(df[cols], label, layer))
    if not frames:
        # pd.concat([]) would fail with an opaque message; same exception type.
        raise ValueError("No data to plot: at least one field and one layer are required.")
    return pd.concat(frames, ignore_index=True)


def _layer_columns(columns: pd.Index, layer: str, single_layer: bool) -> list:
    """Return the column names of a ``make_lines`` table that hold ``layer``'s data."""
    if layer in columns:
        return [layer]
    if single_layer:
        # Every column is attributed to the only requested layer (assumes
        # make_lines returned just that layer's columns -- TODO confirm).
        return list(columns)
    # Flattened columns of a layer are matched by their "<layer> - " prefix.
    prefix = f"{layer} - "
    return [c for c in columns if str(c).startswith(prefix)]


def _long_frame(data, label: str, category) -> pd.DataFrame:
    """Flatten a Series/DataFrame's values into a (value, series, category) frame."""
    values = np.asarray(data.values).ravel()
    return pd.DataFrame({"value": values, "series": label, "category": category})


def _plot_distribution(plot_type: str, plot_kwargs: dict[str, object]) -> Axes:
    """Dispatch to the seaborn plotting function matching ``plot_type``.

    Parameters
    ----------
    plot_type : str
        One of "box", "violin", "boxen", "strip", "swarm", "point" or "bar".
    plot_kwargs : dict
        Keyword arguments forwarded unchanged to the seaborn function.

    Returns
    -------
    Axes
        The axes returned by the seaborn function.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported types.
    """
    plotters = {
        "box": sns.boxplot,
        "violin": sns.violinplot,
        "boxen": sns.boxenplot,
        "strip": sns.stripplot,
        "swarm": sns.swarmplot,
        "point": sns.pointplot,
        "bar": sns.barplot,
    }
    plotter = plotters.get(plot_type)
    if plotter is None:
        raise ValueError(f"Unsupported plot_type: {plot_type}")
    with warnings.catch_warnings():
        # Silence PendingDeprecationWarnings whose message mentions "vert"
        # emitted during plotting.
        warnings.filterwarnings("ignore", category=PendingDeprecationWarning, message=".*vert.*")
        return plotter(**plot_kwargs)