"""Standalone inspection functions for fitted ComBat models."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
import pandas as pd
if TYPE_CHECKING:
from .sklearn_api import ComBat
[docs]
def feature_batch_diagnostics(
combat: ComBat,
mode: Literal["magnitude", "distribution"] = "magnitude",
weighted: bool = True,
) -> pd.DataFrame:
"""Compute per-feature batch effect magnitude.
Returns a DataFrame with columns ``location``, ``scale``, and
``combined``. Location is the (weighted) RMS of gamma across batches
(standardized mean shifts). Scale is the (weighted) RMS of log-delta
across batches (log-fold variance change). Combined is the Euclidean
norm sqrt(location**2 + scale**2). Using RMS provides L2-consistent
aggregation; using log(delta) ensures symmetry.
Parameters
----------
combat : ComBat
A fitted ``ComBat`` instance.
mode : {'magnitude', 'distribution'}, default='magnitude'
- 'magnitude': Returns L2-consistent absolute batch effect magnitudes.
Suitable for ranking, thresholding, and cross-dataset comparison.
- 'distribution': Returns column-wise normalized proportions (each column
sums to 1, values in range [0, 1]), representing the relative contribution
of each feature to the total location, scale, or combined batch effect.
Note: normalization is applied independently to each column, so the
Euclidean relationship (combined**2 = location**2 + scale**2) no longer holds.
weighted : bool, default=True
If True, compute a weighted RMS where each batch is weighted by its
sample size. This gives more influence to larger batches, producing
a more statistically representative summary. If False, all batches
contribute equally regardless of size.
Returns
-------
pd.DataFrame
DataFrame with index=feature names, columns=['location', 'scale', 'combined'],
sorted by 'combined' descending.
Raises
------
ValueError
If the model is not fitted or if mode is invalid.
"""
if not hasattr(combat, "_model") or not hasattr(combat._model, "_gamma_star"):
raise ValueError(
"This ComBat instance is not fitted yet. Call 'fit' before 'feature_batch_diagnostics'."
)
if mode not in ["magnitude", "distribution"]:
raise ValueError(f"mode must be 'magnitude' or 'distribution', got '{mode}'")
feature_names = combat._model._grand_mean.index
gamma_star = combat._model._gamma_star
delta_star = combat._model._delta_star
if weighted:
# Batch sample sizes as weights, normalized to sum to 1
n_per_batch = combat._model._n_per_batch
weights = np.array(
[n_per_batch[str(lvl)] for lvl in combat._model._batch_levels], dtype=np.float64
)
weights = weights / weights.sum()
# Location effect: weighted RMS of gamma across batches
location = np.sqrt((weights[:, np.newaxis] * gamma_star**2).sum(axis=0))
# Scale effect: weighted RMS of log(delta) across batches
if not combat.mean_only:
scale = np.sqrt((weights[:, np.newaxis] * np.log(delta_star) ** 2).sum(axis=0))
else:
scale = np.zeros_like(location)
else:
# Location effect: unweighted RMS of gamma across batches
location = np.sqrt((gamma_star**2).mean(axis=0))
# Scale effect: unweighted RMS of log(delta) across batches
if not combat.mean_only:
scale = np.sqrt((np.log(delta_star) ** 2).mean(axis=0))
else:
scale = np.zeros_like(location)
# Euclidean to treat location and scale as orthogonal dimensions
combined = np.sqrt(location**2 + scale**2)
if mode == "distribution":
# Normalize each column independently to sum to 1
location_sum = location.sum()
scale_sum = scale.sum()
combined_sum = combined.sum()
location = location / location_sum if location_sum > 0 else location
scale = scale / scale_sum if scale_sum > 0 else scale
combined = combined / combined_sum if combined_sum > 0 else combined
return pd.DataFrame(
{
"location": location,
"scale": scale,
"combined": combined,
},
index=feature_names,
).sort_values("combined", ascending=False)
[docs]
def summary(combat: ComBat) -> str:
"""Return a human-readable diagnostic report after fitting.
Parameters
----------
combat : ComBat
A fitted ``ComBat`` instance.
Returns
-------
str
Multi-line summary string.
Raises
------
ValueError
If the model is not fitted.
"""
if not hasattr(combat, "_model") or not hasattr(combat._model, "_gamma_star"):
raise ValueError("This ComBat instance is not fitted yet. Call 'fit' before 'summary'.")
lines: list[str] = []
lines.append("ComBat Summary")
lines.append("=" * 40)
lines.append(f"Method: {combat.method}")
lines.append(f"Parametric: {combat.parametric}")
lines.append(f"Mean only: {combat.mean_only}")
lines.append(f"Reference batch: {combat.reference_batch or 'None'}")
batch_levels = combat._model._batch_levels
n_per_batch = combat._model._n_per_batch
lines.append(f"Number of batches: {len(batch_levels)}")
lines.append("Samples per batch:")
for lvl in batch_levels:
lines.append(f" {lvl}: {n_per_batch[str(lvl)]}")
n_features = len(combat._model._grand_mean)
lines.append(f"Number of features: {n_features}")
lines.append("")
lines.append("Top 5 features by batch effect (combined):")
importance = feature_batch_diagnostics(combat)
top5 = importance.head(5)
for feat, row in top5.iterrows():
lines.append(f" {feat}: {row['combined']:.4f}")
# Diagnostics table
lines.append("")
lines.append("Diagnostics")
lines.append("=" * 40)
col_w = 34
lines.append(f"{'Metric':<{col_w}}Value")
lines.append(f"{'------':<{col_w}}-----")
if hasattr(combat, "_batch_var_before_"):
lines.append(f"{'Batch var. explained (before)':<{col_w}}{combat._batch_var_before_:.1%}")
if hasattr(combat, "_batch_var_after_"):
lines.append(f"{'Batch var. explained (after)':<{col_w}}{combat._batch_var_after_:.1%}")
design_cond = getattr(combat._model, "_design_cond", None)
if design_cond is not None:
lines.append(f"{'Design matrix condition number':<{col_w}}{design_cond:.1f}")
conv_info = getattr(combat._model, "_convergence_info", [])
if conv_info:
eb_type = "parametric" if combat.parametric else "non-parametric"
lines.append(f"EB convergence ({eb_type}):")
for info in conv_info:
batch_name = info["batch"]
if info["converged"]:
status = f"converged ({info['iterations']} iter)"
else:
max_change = max(info["final_gamma_change"], info["final_delta_change"])
status = f"NOT CONVERGED ({info['iterations']} iter, \u0394={max_change:.2e})"
lines.append(f" {batch_name!s:<{col_w - 2}}{status}")
return "\n".join(lines)