import os
import warnings
from typing import Optional, Any, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
def _has_out_of_bounds_artists(fig: Any) -> bool:
    """
    Report whether any plotted line has data outside its axes limits.

    Args:
        fig: Either an object exposing an ``axes`` attribute (a list of axes,
            or a numpy array of axes as on a seaborn grid) or one exposing
            ``get_axes()``.

    Returns:
        True if at least one line returned by ``ax.get_lines()`` has an x or
        y value outside the current limits of its axes. False otherwise,
        including when ``fig`` exposes neither ``axes`` nor ``get_axes``.
    """
    # Imported locally so this helper does not depend on a module-level numpy
    # import; numpy is a hard dependency of matplotlib, so it is available.
    import numpy as np

    if hasattr(fig, "axes"):
        axes = fig.axes
        # A grid stores its axes in a numpy array; iterate it flattened.
        if hasattr(axes, "flat"):
            axes = axes.flat
    elif hasattr(fig, "get_axes"):
        axes = fig.get_axes()
    else:
        return False

    for ax in axes:
        # Sort the limits so an inverted axis (upper bound listed first) is
        # not mistaken for "every point is out of bounds".
        x_lo, x_hi = sorted(ax.get_xlim())
        y_lo, y_hi = sorted(ax.get_ylim())
        for line in ax.get_lines():
            # orig=False requests the unit-converted data, which lives in the
            # same numeric space as the axis limits; asarray also covers
            # plain Python sequences, which do not support `< scalar`.
            xdata = np.asarray(line.get_xdata(orig=False))
            ydata = np.asarray(line.get_ydata(orig=False))
            if (
                ((xdata < x_lo) | (xdata > x_hi)).any()
                or ((ydata < y_lo) | (ydata > y_hi)).any()
            ):
                return True
    return False
[docs]
def get_color_palette(
    n_classes: int,
    palette: Optional[str] = None,
    custom_palette: Optional[str] = None,
) -> Tuple[List[Any], Optional[dict]]:
    """
    Generate a color palette for plotting from a named palette or a custom palette file.

    Args:
        n_classes: Number of colors required.
        palette: Palette name forwarded to ``sns.color_palette``. Used only
            when ``custom_palette`` is not given. When both are omitted the
            "tab20" palette is used and repeated cyclically past 20 classes.
        custom_palette: Path to a text file holding one whitespace-separated
            ``label color`` pair per line. Blank lines are ignored. Takes
            precedence over ``palette``.

    Returns:
        A ``(colors, palette_dict)`` tuple. ``colors`` is the list of colors;
        ``palette_dict`` maps label to color for a custom palette file and is
        ``None`` otherwise. If a label is repeated in the file, its last
        color wins.

    Raises:
        ValueError: If the custom palette file does not contain exactly
            ``n_classes`` non-blank lines, or a non-blank line does not have
            exactly two elements.
    """
    if custom_palette:
        with open(custom_palette, 'r') as file:
            # Drop blank lines up front so a trailing empty line neither
            # inflates the class count nor trips the per-line format check.
            # The original line number is kept for error reporting.
            entries = [
                (line_number, line.split())
                for line_number, line in enumerate(file, start=1)
                if line.strip()
            ]
        if len(entries) != n_classes:
            raise ValueError(f"Custom palette file has {len(entries)} lines, but {n_classes} classes are required.")

        palette_dict = {}
        for line_number, parts in entries:
            if len(parts) != 2:
                raise ValueError(
                    "Custom palette provided should have exactly two elements "
                    f"(label and color) per line; line {line_number} has {len(parts)}."
                )
            label, color = parts
            palette_dict[label] = color
        return list(palette_dict.values()), palette_dict

    if palette:
        colors = sns.color_palette(palette, n_colors=n_classes)
    else:
        # Get the base tab20 palette (20 colors)
        base_palette = sns.color_palette("tab20", n_colors=20)
        # Loop through the palette if more than 20 classes are needed
        if n_classes > 20:
            warnings.warn(f"Requested {n_classes} classes, but tab20 palette only has 20 colors. Colors will be repeated cyclically. Consider using a custom palette with custom_palette for better distinction.", UserWarning)
        colors = [base_palette[i % 20] for i in range(n_classes)]
    return colors, None
[docs]
def set_axis_style(
    ax: Any,
    title: Optional[str] = None,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    tick_labelsize: int = 12,
) -> None:
    """
    Apply title, axis-label and tick-label styling to ``ax``, then despine it.

    A falsy ``title`` (``None`` or an empty string) is skipped, whereas an
    axis label is skipped only when it is ``None``. Tick label size and
    ``sns.despine`` are always applied.
    """
    if title:
        ax.set_title(title, fontsize=18, pad=20)

    # (setter name, label text, font size) for the x and y axis labels; the
    # setter is looked up only when there is a label to apply.
    label_specs = (
        ("set_xlabel", x_label, 14),
        ("set_ylabel", y_label, 16),
    )
    for setter_name, text, size in label_specs:
        if text is None:
            continue
        getattr(ax, setter_name)(text, fontsize=size)

    ax.tick_params(axis='both', labelsize=tick_labelsize)
    sns.despine(ax=ax)
[docs]
def make_dark_mode(
    fig: plt.Figure,
    ax: Any,
    legend_style: Optional[str] = None,
    cbar: Optional[Any] = None,
    gridlines: Optional[Any] = None,
) -> Tuple[plt.Figure, Any]:
    """
    Apply dark mode styling to the given figure and axis.

    The figure and axes backgrounds become dark gray; axis labels, title,
    ticks and spines become white.

    Args:
        fig: Figure whose background patch is recolored.
        ax: Axes to restyle.
        legend_style: Legend texts are turned white only when this is
            ``None`` or ``'default'``; any other value leaves the legend
            untouched.
        cbar: Optional colorbar; its label, ticks and outline are turned
            white on both of its axes.
        gridlines: Optional object whose ``xlabel_style`` and
            ``ylabel_style`` attributes are set to white (presumably a
            cartopy gridliner; confirm against callers).

    Returns:
        The same ``fig`` and ``ax`` objects, modified in place.
    """
    background = '#1F1F1F'
    foreground = 'white'

    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    for text in (ax.xaxis.label, ax.yaxis.label, ax.title):
        text.set_color(foreground)
    ax.tick_params(axis='both', colors=foreground)
    for spine in ax.spines.values():
        spine.set_edgecolor(foreground)

    if gridlines:
        gridlines.xlabel_style = {'color': foreground}
        gridlines.ylabel_style = {'color': foreground}

    if cbar:
        # A vertical colorbar carries its label and ticks on the y-axis and a
        # horizontal one on the x-axis, so style both to cover either
        # orientation.
        cbar.ax.xaxis.label.set_color(foreground)
        cbar.ax.yaxis.label.set_color(foreground)
        cbar.ax.tick_params(axis='both', colors=foreground)
        cbar.outline.set_edgecolor(foreground)

    legend = ax.get_legend()
    if legend and legend_style in (None, 'default'):
        for text in legend.get_texts():
            text.set_color(foreground)
    return fig, ax