Source code for canopy.visualization.visualization_helpers

import os
import warnings
from typing import Optional, Any, List, Tuple

import matplotlib.pyplot as plt
import seaborn as sns

def handle_figure_output(
    fig: plt.Figure,
    output_file: Optional[str] = None,
    transparent: bool = False,
    subfig=None,
) -> Optional[plt.Figure]:
    """
    Figure handler: save or show.

    Saves ``fig`` as a PNG when ``output_file`` is given and then closes it;
    otherwise displays it with ``plt.show()``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure or seaborn grid
        The figure to output. Objects that are not matplotlib Figures (e.g.
        seaborn FacetGrid/PairGrid) are resolved to their underlying figure
        through their ``figure`` (or legacy ``fig``) attribute.
    output_file : str, optional
        Destination path (the extension is forced to ``.png`` by
        ``save_figure_png``). When falsy, the figure is shown instead.
    transparent : bool
        Forwarded to the PNG writer.
    subfig : optional
        When not None the plot already lives inside a parent figure that the
        caller outputs, so nothing is done here.

    Returns
    -------
    matplotlib.figure.Figure or None
        ``fig`` when it was shown; None when it was saved or when ``subfig``
        was provided.
    """
    # If subfig was provided, the plot is already in the parent figure, so return None
    if subfig is not None:
        return None

    if not output_file:
        plt.show()
        return fig

    # save_figure_png writes pyplot's *current* figure, so make the figure we
    # were handed current first; otherwise an unrelated figure may be saved.
    if isinstance(fig, plt.Figure):
        mpl_fig = fig
    else:
        mpl_fig = getattr(fig, "figure", None) or getattr(fig, "fig", None)
    fig_number = getattr(mpl_fig, "number", None)
    if fig_number is not None and plt.fignum_exists(fig_number):
        plt.figure(fig_number)  # activates the existing figure, creates nothing
    # NOTE(review): a figure not managed by pyplot cannot be made current, so
    # in that case the current pyplot figure is still what gets written.

    # Only use bbox_inches='tight' if nothing is out of bounds
    bbox = None if _has_out_of_bounds_artists(fig) else 'tight'
    save_figure_png(output_file, bbox_inches=bbox, transparent=transparent)

    # Close the figure we saved rather than whichever figure happens to be current.
    if mpl_fig is not None:
        plt.close(mpl_fig)
    else:
        plt.close()
    return None
def _has_out_of_bounds_artists(fig: plt.Figure) -> bool:
    """
    Return True if any data-space line in ``fig`` extends past its axes limits.

    Used to decide whether ``bbox_inches='tight'`` is safe when saving.

    Parameters
    ----------
    fig : matplotlib.figure.Figure or seaborn FacetGrid/PairGrid
        Anything exposing ``axes`` (list or numpy array) or ``get_axes()``.

    Returns
    -------
    bool
        False when no axes can be found or no comparable line data exceeds
        the limits.
    """
    import numpy as np

    # Handle Seaborn FacetGrid/PairGrid (and matplotlib Figure, which also has .axes)
    if hasattr(fig, "axes"):
        axes = fig.axes
        # FacetGrid.axes is a numpy array, flatten it
        if hasattr(axes, "flat"):
            axes = axes.flat
    # Handle matplotlib Figure
    elif hasattr(fig, "get_axes"):
        axes = fig.get_axes()
    else:
        return False

    for ax in axes:
        # PairGrid(corner=True) leaves None in the unused upper-triangle slots.
        if ax is None:
            continue
        # sorted() so inverted axes (e.g. after ax.invert_yaxis()) still compare correctly.
        x_lo, x_hi = sorted(ax.get_xlim())
        y_lo, y_hi = sorted(ax.get_ylim())
        for line in ax.get_lines():
            # axhline/axvline use axes-fraction coordinates on one axis; only
            # lines drawn purely in data coordinates are comparable to limits.
            if line.get_transform() is not ax.transData:
                continue
            try:
                xdata = np.asarray(line.get_xdata(), dtype=float)
                ydata = np.asarray(line.get_ydata(), dtype=float)
            except (TypeError, ValueError):
                # Non-numeric data (categories, datetime objects) can't be
                # compared against the float limits; skip rather than crash.
                continue
            if ((xdata < x_lo).any() or (xdata > x_hi).any()
                    or (ydata < y_lo).any() or (ydata > y_hi).any()):
                return True
    return False
def save_figure_png(
    output_file: str,
    bbox_inches: Optional[str] = None,
    transparent: bool = False,
    fig: Optional[Any] = None,
    dpi: int = 300,
) -> None:
    """
    Save a matplotlib figure as a PNG file.

    Parameters
    ----------
    output_file : str
        Destination path. Any existing extension is replaced by ``.png``;
        missing parent directories are created.
    bbox_inches : str, optional
        Forwarded to ``savefig`` (e.g. ``'tight'``).
    transparent : bool
        Forwarded to ``savefig``.
    fig : optional
        Object with a ``savefig`` method (matplotlib Figure or seaborn grid)
        to save. Defaults to None, which saves pyplot's current figure.
    dpi : int
        Output resolution; defaults to 300.
    """
    # Ensure the extension is .png
    base, _ = os.path.splitext(output_file)
    png_path = f"{base}.png"

    # Create directory if it doesn't exist (exist_ok makes a pre-check unnecessary)
    directory = os.path.dirname(png_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Without an explicit figure, fall back to pyplot's current figure.
    target = plt if fig is None else fig
    target.savefig(
        png_path,
        format="png",
        dpi=dpi,
        bbox_inches=bbox_inches,
        transparent=transparent,
    )
def get_color_palette(
    n_classes: int,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
) -> Tuple[List[Any], Optional[dict]]:
    """
    Generate a color palette for plotting based on either a ColorBrewer
    palette or a custom palette file.

    Parameters
    ----------
    n_classes : int
        Number of colors required.
    palette : str, optional
        Name of a seaborn/matplotlib palette. Ignored when ``custom_palette``
        is given; when neither is given, ``tab20`` is used (cycled if more
        than 20 colors are needed).
    custom_palette : str, optional
        Path to a text file with one whitespace-separated ``label color`` pair
        per line. Blank lines are ignored.

    Returns
    -------
    (colors, palette_dict)
        ``colors`` is the list of colors. ``palette_dict`` maps label to color
        (in file order) for a custom palette, and is None otherwise.

    Raises
    ------
    ValueError
        If the custom palette file does not define exactly ``n_classes``
        entries, a line does not hold exactly two fields, or a label repeats.
    """
    if custom_palette:
        with open(custom_palette, 'r', encoding='utf-8') as file:
            # Ignore blank lines (e.g. a trailing empty line at end of file).
            lines = [line for line in file if line.strip()]
        if len(lines) != n_classes:
            raise ValueError(f"Custom palette file has {len(lines)} lines, but {n_classes} classes are required.")
        palette_dict = {}
        for line in lines:
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(
                    "Custom palette provided should have exactly two elements "
                    f"(label and color) per line, got {len(parts)}: {line.strip()!r}"
                )
            label, color = parts
            if label in palette_dict:
                # A repeated label would silently shrink the palette below n_classes.
                raise ValueError(f"Custom palette file defines label {label!r} more than once.")
            palette_dict[label] = color
        return list(palette_dict.values()), palette_dict

    if palette:
        return sns.color_palette(palette, n_colors=n_classes), None

    # Get the base tab20 palette (20 colors)
    base_palette = sns.color_palette("tab20", n_colors=20)
    # Loop through the palette if more than 20 classes are needed
    if n_classes > 20:
        warnings.warn(f"Requested {n_classes} classes, but tab20 palette only has 20 colors. Colors will be repeated cyclically. Consider using a custom palette with custom_palette for better distinction.", UserWarning)
    return [base_palette[i % 20] for i in range(n_classes)], None
def get_field_metadata(
    field_or_fields,
    label: Optional[str] = None,
    unit: Optional[str] = None,
    layers: Optional[Any] = None,
) -> Tuple[str, str, Any]:
    """Fill in any missing label, unit, or layers from a field's metadata.

    ``field_or_fields`` may be a single field or a list/tuple of fields; for a
    sequence only its first element is consulted. Each of ``label``, ``unit``
    and ``layers`` is kept when truthy; otherwise it is replaced by
    ``str(metadata["name"])``, ``str(metadata["units"])`` (each defaulting to
    ``""`` when the key is absent) and ``field.layers`` respectively.

    Returns
    -------
    (label, unit, layers)
    """
    if isinstance(field_or_fields, (list, tuple)):
        field = field_or_fields[0]
    else:
        field = field_or_fields

    # Metadata/layers are only read when the caller did not supply a value.
    resolved_label = label if label else str(field.metadata.get("name", ""))
    resolved_unit = unit if unit else str(field.metadata.get("units", ""))
    resolved_layers = layers if layers else field.layers
    return resolved_label, resolved_unit, resolved_layers
def format_value_label(yaxis_label: Optional[str], unit: Optional[str]) -> str:
    """Format an axis label with its unit, e.g. ``"LAI (in m2/m2)"``.

    Parameters
    ----------
    yaxis_label : str, optional
        The label text; None is treated as an empty label.
    unit : str, optional
        The unit. Empty/None or the placeholder ``"[no units]"`` means no
        unit suffix is added.

    Returns
    -------
    str
        ``"<label> (in <unit>)"``, ``"(in <unit>)"`` when there is no label,
        or the label alone (``""`` if missing) when there is no unit.
    """
    label = yaxis_label or ""
    if not unit or unit == "[no units]":
        return label
    suffix = f"(in {unit})"
    # Avoid rendering a missing label as the literal text "None" or a leading space.
    return f"{label} {suffix}" if label else suffix
def set_axis_style(
    ax: Any,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    tick_labelsize: int = 12,
) -> None:
    """Apply the shared axis look: title, axis labels, tick size, despine.

    The title (font size 18, pad 20) is set only when it is truthy. The x
    label (size 14) and y label (size 16) are set whenever they are not None,
    so passing ``""`` clears a label. Tick labels on both axes get
    ``tick_labelsize``, and seaborn's ``despine`` is applied to ``ax``.
    """
    if title:
        ax.set_title(title, fontsize=18, pad=20)

    # (text, setter method name, font size), applied in x-then-y order.
    label_specs = (
        (x_label, "set_xlabel", 14),
        (y_label, "set_ylabel", 16),
    )
    for text, setter_name, font_size in label_specs:
        if text is None:
            continue
        getattr(ax, setter_name)(text, fontsize=font_size)

    ax.tick_params(axis='both', labelsize=tick_labelsize)
    sns.despine(ax=ax)
def make_dark_mode(
    fig: plt.Figure,
    ax: Any,
    legend_style: Optional[str] = None,
    cbar: Optional[Any] = None,
    gridlines: Optional[Any] = None,
) -> Tuple[plt.Figure, Any]:
    """
    Apply dark mode styling to the given figure and axis.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose background patch is darkened.
    ax : matplotlib Axes
        Axes whose background, labels, title, ticks and spines are restyled.
    legend_style : str, optional
        When None or ``'default'``, the axes legend (if any) is restyled for a
        dark background; any other value leaves the legend untouched.
    cbar : optional
        Colorbar to restyle (label, ticks and outline, for either orientation).
    gridlines : optional
        Object with ``xlabel_style``/``ylabel_style`` dicts (presumably a
        cartopy Gridliner; verify against callers); ``'color'`` is set to
        white while other style keys are kept.

    Returns
    -------
    (fig, ax)
        The same objects, modified in place.
    """
    dark_gray = '#1F1F1F'
    text_color = 'white'

    fig.patch.set_facecolor(dark_gray)
    ax.set_facecolor(dark_gray)
    ax.xaxis.label.set_color(text_color)
    ax.yaxis.label.set_color(text_color)
    ax.title.set_color(text_color)
    ax.tick_params(axis='x', colors=text_color)
    ax.tick_params(axis='y', colors=text_color)
    for spine in ax.spines.values():
        spine.set_edgecolor(text_color)

    if gridlines is not None:
        # Merge rather than replace so existing label styles (size, weight) survive.
        gridlines.xlabel_style = {**(getattr(gridlines, 'xlabel_style', None) or {}), 'color': text_color}
        gridlines.ylabel_style = {**(getattr(gridlines, 'ylabel_style', None) or {}), 'color': text_color}

    if cbar is not None:
        # A vertical colorbar (matplotlib's default) carries its label and
        # ticks on the y axis, a horizontal one on the x axis: style both.
        cbar.ax.xaxis.label.set_color(text_color)
        cbar.ax.yaxis.label.set_color(text_color)
        cbar.ax.tick_params(axis='x', colors=text_color)
        cbar.ax.tick_params(axis='y', colors=text_color)
        cbar.outline.set_edgecolor(text_color)

    legend = ax.get_legend()
    if legend is not None and (legend_style is None or legend_style == 'default'):
        for text in legend.get_texts():
            text.set_color(text_color)
        legend.get_title().set_color(text_color)
        # The legend frame keeps the light facecolor it was created with;
        # darken it so the white legend text stays readable.
        legend.get_frame().set_facecolor(dark_gray)

    return fig, ax