"""Visualization utilities for ComBat batch correction."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
import numpy.typing as npt
import pandas as pd
from ._utils import _subset
from .core import ArrayLike, FloatArray
from .inspection import feature_batch_diagnostics
if TYPE_CHECKING:
from .sklearn_api import ComBat
def _scatter_panel(
ax: Any,
X: FloatArray,
batch_labels: pd.Series,
unique_batches: pd.Series,
colors: npt.NDArray[Any],
n_components: int,
point_size: int,
alpha: float,
) -> None:
"""Scatter one panel (before or after) onto the given axes."""
for i, batch in enumerate(unique_batches):
mask = batch_labels == batch
coords = [X[mask, j] for j in range(n_components)]
ax.scatter(
*coords,
c=[colors[i]],
s=point_size,
alpha=alpha,
label=f"Batch {batch}",
edgecolors="black",
linewidth=0.5,
)
def _create_static_plot(
    X_orig: FloatArray,
    X_trans: FloatArray,
    batch_labels: pd.Series,
    method: str,
    n_components: int,
    figsize: tuple[int, int],
    alpha: float,
    point_size: int,
    cmap: str,
    title: str | None,
    show_legend: bool,
) -> Any:
    """Build a side-by-side before/after matplotlib figure.

    The left panel shows ``X_orig`` and the right panel ``X_trans``, with
    points coloured by batch. Colours are sampled evenly from ``cmap`` when
    there are at most ten batches and from ``"tab20"`` otherwise. Two
    components give 2D axes; any other value gives 3D axes. A legend is
    attached to the right panel only when ``show_legend`` is true and there
    are at most twenty batches. When ``title`` is ``None`` a default
    suptitle naming the method is used.

    Returns the created figure.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=figsize)

    unique_batches = batch_labels.drop_duplicates()
    n_batches = len(unique_batches)

    # Beyond ten batches the requested colormap is ignored in favour of "tab20".
    palette_name = cmap if n_batches <= 10 else "tab20"
    colors = matplotlib.colormaps.get_cmap(palette_name)(np.linspace(0, 1, n_batches))

    if n_components != 2:
        ax_before = fig.add_subplot(121, projection="3d")
        ax_after = fig.add_subplot(122, projection="3d")
    else:
        ax_before = plt.subplot(1, 2, 1)
        ax_after = plt.subplot(1, 2, 2)

    method_upper = method.upper()
    panels = ((ax_before, X_orig, "Before"), (ax_after, X_trans, "After"))
    for ax, data, stage in panels:
        _scatter_panel(
            ax,
            data,
            batch_labels=batch_labels,
            unique_batches=unique_batches,
            colors=colors,
            n_components=n_components,
            point_size=point_size,
            alpha=alpha,
        )
        ax.set_title(f"{stage} ComBat correction\n({method_upper})")
        ax.set_xlabel(f"{method_upper}1")
        ax.set_ylabel(f"{method_upper}2")
        if n_components == 3:
            ax.set_zlabel(f"{method_upper}3")  # type: ignore[attr-defined]

    legend_wanted = show_legend and n_batches <= 20
    if legend_wanted:
        # Anchor the legend just outside the right-hand panel.
        ax_after.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    suptitle = (
        title
        if title is not None
        else f"ComBat correction effect visualized with {method_upper}"
    )
    fig.suptitle(suptitle, fontsize=14, fontweight="bold")

    plt.tight_layout()
    return fig
def _add_plotly_traces(
    fig: Any,
    X: FloatArray,
    batch_labels: pd.Series,
    unique_batches: pd.Series,
    batch_to_color: dict[Any, Any],
    n_components: int,
    col: int,
    showlegend: bool,
) -> None:
    """Add one marker trace per batch to row 1, column ``col`` of ``fig``.

    With ``n_components == 2`` a ``go.Scatter`` trace is built from the
    first two columns of ``X`` (marker size 8, outline width 1); otherwise
    a ``go.Scatter3d`` trace is built from the first three columns
    (marker size 5, outline width 0.5). Each trace is named
    ``"Batch <name>"`` and coloured via ``batch_to_color``.
    """
    import plotly.graph_objects as go

    planar = n_components == 2
    trace_cls = go.Scatter if planar else go.Scatter3d
    marker_size, outline_width = (8, 1) if planar else (5, 0.5)

    for batch in unique_batches:
        rows = batch_labels == batch

        coords = {"x": X[rows, 0], "y": X[rows, 1]}
        if not planar:
            coords["z"] = X[rows, 2]

        trace = trace_cls(
            **coords,
            mode="markers",
            name=f"Batch {batch}",
            marker={
                "size": marker_size,
                "color": batch_to_color[batch],
                "line": {"width": outline_width, "color": "black"},
            },
            showlegend=showlegend,
        )
        fig.add_trace(trace, row=1, col=col)
def _create_interactive_plot(
    X_orig: FloatArray,
    X_trans: FloatArray,
    batch_labels: pd.Series,
    method: str,
    n_components: int,
    cmap: str,
    title: str | None,
    show_legend: bool,
) -> Any:
    """Build a side-by-side before/after plotly figure.

    Column 1 shows ``X_orig`` and column 2 shows ``X_trans``, coloured by
    batch with hex colours sampled evenly from the matplotlib colormap
    ``cmap``. Two components give 2D scatter subplots; any other value
    gives ``scatter3d`` subplots. Only the right-hand traces may contribute
    legend entries, and the legend as a whole follows ``show_legend``.
    When ``title`` is ``None`` a default title naming the method is used.

    Returns the created figure.
    """
    import matplotlib
    import matplotlib.colors as mcolors
    from plotly.subplots import make_subplots

    method_upper = method.upper()
    is_2d = n_components == 2

    subplot_kwargs: dict[str, Any] = {
        "rows": 1,
        "cols": 2,
        "subplot_titles": (
            f"Before ComBat correction ({method_upper})",
            f"After ComBat correction ({method_upper})",
        ),
    }
    if not is_2d:
        # 3D panels need an explicit subplot type.
        subplot_kwargs["specs"] = [[{"type": "scatter3d"}, {"type": "scatter3d"}]]
    fig = make_subplots(**subplot_kwargs)

    unique_batches = batch_labels.drop_duplicates()
    n_batches = len(unique_batches)

    colormap = matplotlib.colormaps.get_cmap(cmap)
    # Guard the denominator so a single batch maps to position 0.
    span = max(n_batches - 1, 1)
    hex_colors = [mcolors.to_hex(colormap(k / span)) for k in range(n_batches)]
    batch_to_color = dict(zip(unique_batches, hex_colors, strict=True))

    # Left panel never shows legend entries, avoiding duplicated batch names.
    for column, data, legend_flag in ((1, X_orig, False), (2, X_trans, show_legend)):
        _add_plotly_traces(
            fig,
            data,
            batch_labels=batch_labels,
            unique_batches=unique_batches,
            batch_to_color=batch_to_color,
            n_components=n_components,
            col=column,
            showlegend=legend_flag,
        )

    figure_title = (
        title
        if title is not None
        else f"ComBat correction effect visualized with {method_upper}"
    )
    fig.update_layout(
        title=figure_title,
        title_font_size=16,
        height=600,
        showlegend=show_legend,
        hovermode="closest",
    )

    axis_labels = [f"{method_upper}{k + 1}" for k in range(n_components)]
    if is_2d:
        fig.update_xaxes(title_text=axis_labels[0])
        fig.update_yaxes(title_text=axis_labels[1])
    else:
        fig.update_scenes(
            xaxis_title=axis_labels[0],
            yaxis_title=axis_labels[1],
            zaxis_title=axis_labels[2],
        )

    return fig
[docs]
def plot_feature_diagnostics(
    combat: ComBat,
    top_n: int = 20,
    kind: Literal["location", "scale", "combined"] = "combined",
    mode: Literal["magnitude", "distribution"] = "magnitude",
    weighted: bool = True,
    figsize: tuple[int, int] = (8, 10),
) -> Any:
    """Plot top features affected by batch effects.

    Parameters
    ----------
    combat : ComBat
        A fitted ``ComBat`` instance.
    top_n : int, default=20
        Maximum number of features to display. Must be at least 1. If fewer
        features are available, all of them are shown and the title reports
        the number actually displayed.
    kind : {'location', 'scale', 'combined'}, default='combined'
        - 'location': bar plot of location (mean shift) contribution only
        - 'scale': bar plot of scale (variance) contribution only
        - 'combined': grouped bar plot showing location and scale
          side-by-side for each feature (sorted by Euclidean magnitude).
          In magnitude mode: bars reflect Euclidean decomposition
          (combined**2 = location**2 + scale**2).
          In distribution mode: bars reflect independent normalized
          contributions (each sums to 1 separately).
    mode : {'magnitude', 'distribution'}, default='magnitude'
        - 'magnitude': x-axis shows absolute batch effect magnitude
        - 'distribution': x-axis shows relative contribution (proportion);
          additionally prints the cumulative contribution of the displayed
          features to stdout
          (e.g., "Top 20 features explain 75.0% of total batch effect")
    weighted : bool, default=True
        If True, batch effects are weighted by batch sample size.
        Passed to :func:`feature_batch_diagnostics`.
    figsize : tuple, default=(8,10)
        Figure size (width, height) in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the plot.

    Raises
    ------
    ValueError
        If the model is not fitted, if kind/mode is invalid, or if
        ``top_n`` is smaller than 1.

    Notes
    -----
    Features are the first ``top_n`` rows returned by
    :func:`feature_batch_diagnostics`, in that order; they are not re-sorted
    by the ``kind`` column here.
    """
    if not hasattr(combat, "_model") or not hasattr(combat._model, "_gamma_star"):
        raise ValueError(
            "This ComBat instance is not fitted yet. Call 'fit' before 'plot_feature_diagnostics'."
        )
    if kind not in ("location", "scale", "combined"):
        raise ValueError(f"kind must be 'location', 'scale', or 'combined', got '{kind}'")
    if mode not in ("magnitude", "distribution"):
        raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")
    # DataFrame.head(n) with n <= 0 returns nothing or "all but the last |n|"
    # rows, which would silently plot the wrong features.
    if top_n < 1:
        raise ValueError(f"top_n must be a positive integer, got {top_n}")

    # Imported only after validation so bad arguments fail fast.
    import matplotlib.pyplot as plt

    importance_df = feature_batch_diagnostics(combat, mode=mode, weighted=weighted)
    # Reverse so the first-ranked feature ends up at the top of the
    # horizontal bar plot.
    top_features = importance_df.head(top_n).iloc[::-1]
    # May be smaller than top_n when the model has fewer features.
    n_shown = len(top_features)

    positions = np.arange(n_shown)
    bar_style = {"edgecolor": "black", "linewidth": 0.5}

    fig, ax = plt.subplots(figsize=figsize)

    if kind == "combined":
        # Grouped bars: location above, scale below each tick position.
        bar_height = 0.35
        ax.barh(
            positions + bar_height / 2,
            top_features["location"],
            bar_height,
            label="Location",
            color="steelblue",
            **bar_style,
        )
        ax.barh(
            positions - bar_height / 2,
            top_features["scale"],
            bar_height,
            label="Scale",
            color="coral",
            **bar_style,
        )
        ax.legend()
    else:
        ax.barh(
            positions,
            top_features[kind],
            color="steelblue" if kind == "location" else "coral",
            **bar_style,
        )

    ax.set_yticks(positions)
    ax.set_yticklabels(top_features.index)
    ax.set_ylabel("Feature")

    if mode == "magnitude":
        ax.set_xlabel("Batch Effect Magnitude (RMS)")
        title_str = f"Top {n_shown} Features by Batch Effect"
    else:
        ax.set_xlabel("Relative Contribution")
        title_str = f"Top {n_shown} Features by Batch Effect (Distribution)"

    kind_suffix = "Location & Scale" if kind == "combined" else kind.capitalize()
    ax.set_title(f"{title_str} ({kind_suffix})")

    fig.tight_layout()

    if mode == "distribution":
        # The validated `kind` value is also the name of the column to sum.
        cumulative_pct = top_features[kind].sum() * 100
        effect_label = "batch effect" if kind == "combined" else f"{kind} effect"
        print(f"Top {n_shown} features explain {cumulative_pct:.1f}% of total {effect_label}")

    return fig
[docs]
def plot_batch_effect_heatmap(
    combat: ComBat,
    top_n: int = 50,
    weighted: bool = True,
    figsize: tuple[int, int] = (12, 8),
) -> Any:
    """Plot a heatmap of batch effect parameters across features and batches.

    Displays the estimated batch-specific location shifts (gamma) and,
    unless ``mean_only=True``, log-scale shifts (log delta) for the
    ``top_n`` most affected features.

    Parameters
    ----------
    combat : ComBat
        A fitted ``ComBat`` instance.
    top_n : int, default=50
        Maximum number of top features (by combined batch effect) to
        display. Must be at least 1. If fewer features are available, all
        of them are shown and the title reports the number displayed.
    weighted : bool, default=True
        If True, feature ranking uses sample-size-weighted batch effects.
        Passed to :func:`feature_batch_diagnostics`.
    figsize : tuple of int, default=(12, 8)
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the heatmap(s).

    Raises
    ------
    ValueError
        If the model is not fitted, or if ``top_n`` is smaller than 1.
    ImportError
        If seaborn is not installed.
    """
    if not hasattr(combat, "_model") or not hasattr(combat._model, "_gamma_star"):
        raise ValueError(
            "This ComBat instance is not fitted yet. Call 'fit' before 'plot_batch_effect_heatmap'."
        )
    # top_n == 0 would yield an empty (float) index array that cannot be
    # used for column selection, and a negative value makes DataFrame.head
    # return "all but the last |n|" rows.
    if top_n < 1:
        raise ValueError(f"top_n must be a positive integer, got {top_n}")

    # Imported only after validation so bad arguments fail fast.
    import matplotlib.pyplot as plt
    import seaborn as sns

    model = combat._model
    feature_names = model._grand_mean.index
    batch_names = [str(b) for b in model._batch_levels]

    importance = feature_batch_diagnostics(combat, weighted=weighted)
    top_features = importance.head(top_n).index
    # Column positions of the selected features in the fitted parameter arrays.
    feat_idx = np.array([feature_names.get_loc(f) for f in top_features], dtype=np.intp)

    # Each panel is (title, values); the scale panel is skipped for mean-only
    # models, so the delta estimates are only read when they are plotted.
    panels = [("Location shifts (gamma)", model._gamma_star[:, feat_idx])]
    if not combat.mean_only:
        panels.append(("Scale shifts (log delta)", np.log(model._delta_star[:, feat_idx])))

    # squeeze=False always returns a 2D axes array, so one and two panels
    # are handled uniformly.
    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)

    for ax, (panel_title, values) in zip(axes[0], panels, strict=True):
        panel_df = pd.DataFrame(values, index=batch_names, columns=top_features)
        sns.heatmap(
            panel_df,
            cmap="RdBu_r",
            center=0,
            ax=ax,
            xticklabels=True,
            yticklabels=True,
        )
        ax.set_title(panel_title)
        ax.set_xlabel("Feature")
        ax.set_ylabel("Batch")

    fig.suptitle(
        f"Batch Effect Heatmap (top {len(top_features)} features)",
        fontsize=14,
        fontweight="bold",
    )
    fig.tight_layout()
    return fig