"""Batch effect metrics and diagnostics."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.spatial.distance import pdist
from scipy.stats import chi2, levene, spearmanr
from sklearn.decomposition import PCA
from sklearn.metrics import davies_bouldin_score, silhouette_score
from sklearn.neighbors import NearestNeighbors
from ._utils import _subset
from .core import ArrayLike
if TYPE_CHECKING:
from .sklearn_api import ComBat
def _compute_pca_embedding(
    X_before: npt.NDArray[Any],
    X_after: npt.NDArray[Any],
    n_components: int,
) -> tuple[npt.NDArray[Any], npt.NDArray[Any], PCA]:
    """
    Project both datasets onto principal axes learned from ``X_before``.

    The PCA model is fitted on the uncorrected data only and then reused to
    transform the corrected data, so both embeddings share the same axes.

    Parameters
    ----------
    X_before : npt.NDArray[Any]
        Data before correction; used to fit the PCA model.
    X_after : npt.NDArray[Any]
        Data after correction; only transformed, never used for fitting.
    n_components : int
        Requested number of components. Silently capped at
        ``min(n_features, n_samples - 1)`` of ``X_before``.

    Returns
    -------
    X_before_pca : npt.NDArray[Any]
        Embedding of ``X_before``.
    X_after_pca : npt.NDArray[Any]
        Embedding of ``X_after`` in the same component space.
    pca : PCA
        The fitted PCA model.
    """
    n_rows = X_before.shape[0]
    n_cols = X_before.shape[1]

    # Never ask PCA for more components than the data can support.
    effective_components = min(n_components, n_cols, n_rows - 1)

    # Fixed seed keeps randomized SVD solvers reproducible between calls.
    model = PCA(n_components=effective_components, random_state=42)
    embedded_before = model.fit_transform(X_before)
    embedded_after = model.transform(X_after)
    return embedded_before, embedded_after, model
def _silhouette_batch(X: npt.NDArray[Any], batch_labels: npt.NDArray[Any]) -> float:
    """
    Silhouette coefficient of the data when batches are treated as clusters.

    The score lies in [-1, 1]. Values near 1 mean samples sit in tight,
    well-separated batch clusters (strong batch effect); values near 0 or
    below mean the batches overlap. A lower value after correction therefore
    indicates better batch mixing.

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.

    Returns
    -------
    float
        Silhouette coefficient (Euclidean metric), or ``0.0`` when fewer
        than two batches are present or the score cannot be computed.
    """
    if np.unique(batch_labels).size < 2:
        # A single batch has nothing to be separated from.
        return 0.0

    try:
        score = float(silhouette_score(X, batch_labels, metric="euclidean"))
    except Exception:
        # Best-effort metric: any failure is reported as the neutral value.
        score = 0.0
    return score
def _davies_bouldin_batch(X: npt.NDArray[Any], batch_labels: npt.NDArray[Any]) -> float:
    """
    Davies-Bouldin index of the data when batches are treated as clusters.

    The index lies in [0, inf). Smaller values mean the batch "clusters"
    are less distinct, i.e. batches are better mixed.

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.

    Returns
    -------
    float
        Davies-Bouldin index, or ``0.0`` when fewer than two batches are
        present or the index cannot be computed.
    """
    if np.unique(batch_labels).size < 2:
        # The index is undefined for a single cluster.
        return 0.0

    try:
        index = float(davies_bouldin_score(X, batch_labels))
    except Exception:
        # Best-effort metric: any failure is reported as the neutral value.
        index = 0.0
    return index
def _kbet_score(
    X: npt.NDArray[Any],
    batch_labels: npt.NDArray[Any],
    k0: int,
    alpha: float = 0.05,
    nn_algorithm: str = "auto",
) -> tuple[float, float]:
    """
    Compute kBET (k-nearest neighbor Batch Effect Test) acceptance rate.

    For every sample, the batch composition of its ``k0`` nearest neighbours
    is compared with the global batch composition using a chi-squared
    goodness-of-fit test. Higher acceptance rate = better batch mixing.

    Reference: Buttner et al. (2019) Nature Methods

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch labels for each sample.
    k0 : int
        Neighborhood size. Capped at ``n_samples - 1``.
    alpha : float
        Significance level for chi-squared test.
    nn_algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used for nearest neighbor computation. Passed to
        ``sklearn.neighbors.NearestNeighbors``.

    Returns
    -------
    acceptance_rate : float
        Fraction of samples where H0 (local composition matches the global
        composition) is not rejected at level ``alpha``.
    mean_stat : float
        Mean chi-squared statistic across samples.

    Notes
    -----
    Returns ``(1.0, 0.0)`` when fewer than two batches are present or when
    the effective neighbourhood size is zero, because no test can be run.
    """
    n_samples = X.shape[0]
    # ``batch_codes`` maps every sample to the position of its label in
    # ``unique_batches``, so neighbour labels can be counted as integers.
    unique_batches, batch_codes, batch_counts = np.unique(
        batch_labels, return_inverse=True, return_counts=True
    )
    n_batches = len(unique_batches)
    if n_batches < 2:
        return 1.0, 0.0
    batch_codes = batch_codes.reshape(-1)

    k0 = min(k0, n_samples - 1)
    if k0 == 0:
        # Empty neighbourhoods: every expected count is zero, nothing to test.
        return 1.0, 0.0

    nn = NearestNeighbors(n_neighbors=k0 + 1, algorithm=nn_algorithm)
    nn.fit(X)
    _, indices = nn.kneighbors(X)

    # Column 0 is taken to be the query point itself and is skipped.
    neighbor_codes = batch_codes[indices[:, 1 : k0 + 1]]

    # Count batches per neighbourhood in one pass: shifting the codes of row
    # i by i * n_batches gives each (sample, batch) pair its own bin.
    row_offsets = np.arange(n_samples)[:, np.newaxis] * n_batches
    observed = np.bincount(
        (neighbor_codes + row_offsets).ravel(), minlength=n_samples * n_batches
    ).reshape(n_samples, n_batches)

    # Expected counts are identical for every sample; every entry is > 0
    # because each unique batch occurs at least once and k0 >= 1 here.
    expected = (batch_counts / n_samples) * k0
    chi2_stats = np.sum((observed - expected) ** 2 / expected, axis=1)

    # The survival function avoids the cancellation in ``1 - cdf`` for
    # large statistics.
    p_values = chi2.sf(chi2_stats, n_batches - 1)

    acceptance_rate = np.mean(p_values > alpha)
    mean_stat = np.mean(chi2_stats)
    return float(acceptance_rate), float(mean_stat)
def _find_sigma(distances: npt.NDArray[Any], target_perplexity: float, tol: float = 1e-5) -> float:
    """
    Bisection search for the Gaussian bandwidth matching a target perplexity.

    Used in LISI computation. The kernel weights ``exp(-d^2 / (2 sigma^2))``
    are normalised to a distribution whose base-2 entropy is compared with
    ``log2(target_perplexity)``; sigma is widened when the entropy is too
    low and narrowed when it is too high.

    Parameters
    ----------
    distances : npt.NDArray[Any]
        Distances to neighbors.
    target_perplexity : float
        Target perplexity value.
    tol : float
        Absolute entropy tolerance at which the search stops.

    Returns
    -------
    float
        Sigma value reached after at most 50 bisection steps.
    """
    entropy_goal = np.log2(target_perplexity + 1e-10)
    lower, upper = 1e-10, 1e10
    sigma = 1.0
    squared = distances**2

    for _step in range(50):
        weights = np.exp(-squared / (2 * sigma**2 + 1e-10))
        total = weights.sum()

        if total < 1e-10:
            # Kernel collapsed to (almost) nothing: sigma is far too small.
            lower = sigma
        else:
            # Clip so that log2 never sees an exact zero.
            probs = np.clip(weights / total, 1e-10, 1.0)
            entropy = -np.sum(probs * np.log2(probs))
            if abs(entropy - entropy_goal) < tol:
                break
            if entropy_goal > entropy:
                lower = sigma
            else:
                upper = sigma

        sigma = (lower + upper) / 2

    return sigma
def _lisi_score(
    X: npt.NDArray[Any],
    batch_labels: npt.NDArray[Any],
    perplexity: int = 30,
    nn_algorithm: str = "auto",
) -> float:
    """
    Compute mean Local Inverse Simpson's Index (LISI).

    For each sample, neighbours are weighted with a Gaussian kernel whose
    bandwidth is tuned to ``perplexity``; the weights are summed per batch
    and the inverse Simpson index of that distribution is taken.

    Range: [1, n_batches], where n_batches = perfect mixing.
    Higher = better batch mixing.

    Reference: Korsunsky et al. (2019) Nature Methods (Harmony paper)

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch labels for each sample.
    perplexity : int
        Perplexity for Gaussian kernel. The neighbourhood size is
        ``min(3 * perplexity, n_samples - 1)``.
    nn_algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used for nearest neighbor computation. Passed to
        ``sklearn.neighbors.NearestNeighbors``.

    Returns
    -------
    float
        Mean LISI score; ``1.0`` when fewer than two batches are present.
    """
    n_samples = X.shape[0]
    # Integer code per sample = position of its label in ``unique_batches``.
    unique_batches, batch_codes = np.unique(batch_labels, return_inverse=True)
    n_batches = len(unique_batches)
    if n_batches < 2:
        return 1.0
    batch_codes = batch_codes.reshape(-1)

    k = min(3 * perplexity, n_samples - 1)
    nn = NearestNeighbors(n_neighbors=k + 1, algorithm=nn_algorithm)
    nn.fit(X)
    distances, indices = nn.kneighbors(X)
    # Column 0 is taken to be the query point itself and is dropped.
    distances = distances[:, 1:]
    indices = indices[:, 1:]

    lisi_values = []
    for row_distances, row_indices in zip(distances, indices):
        sigma = _find_sigma(row_distances, perplexity)
        weights = np.exp(-(row_distances**2) / (2 * sigma**2 + 1e-10))
        total = weights.sum()
        if total < 1e-10:
            # Degenerate kernel: report the minimum possible diversity.
            lisi_values.append(1.0)
            continue
        weights = weights / total

        # Sum the kernel weights per batch in one call.
        batch_probs = np.bincount(
            batch_codes[row_indices], weights=weights, minlength=n_batches
        )
        simpson = np.sum(batch_probs**2)
        lisi_values.append(n_batches if simpson < 1e-10 else 1.0 / simpson)

    return float(np.mean(lisi_values))
def _variance_ratio(X: npt.NDArray[Any], batch_labels: npt.NDArray[Any]) -> float:
    """
    Compute between-batch to within-batch variance ratio.

    Similar to the F-statistic in one-way ANOVA, with squared deviations
    summed over all features. Lower ratio after correction = better batch
    effect removal.

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch labels for each sample.

    Returns
    -------
    float
        Variance ratio (between/within). ``0.0`` is returned when the ratio
        is undefined: fewer than two batches, no within-batch degrees of
        freedom (every batch holds a single sample), or a within-batch
        variance that is numerically zero.
    """
    unique_batches = np.unique(batch_labels)
    n_batches = len(unique_batches)
    n_samples = X.shape[0]
    if n_batches < 2:
        return 0.0

    within_dof = n_samples - n_batches
    if within_dof <= 0:
        # Dividing by zero degrees of freedom would yield NaN; treat this
        # like the "no within-batch variance" case below.
        return 0.0

    grand_mean = np.mean(X, axis=0)
    between_ss = 0.0
    within_ss = 0.0
    for batch in unique_batches:
        X_batch = X[batch_labels == batch]
        batch_mean = np.mean(X_batch, axis=0)
        between_ss += X_batch.shape[0] * np.sum((batch_mean - grand_mean) ** 2)
        within_ss += np.sum((X_batch - batch_mean) ** 2)

    between_var = between_ss / (n_batches - 1)
    within_var = within_ss / within_dof
    if within_var < 1e-10:
        return 0.0
    return float(between_var / within_var)
def _knn_preservation(
    X_before: npt.NDArray[Any],
    X_after: npt.NDArray[Any],
    k_values: list[int],
    n_jobs: int = 1,
    nn_algorithm: str = "auto",
) -> dict[int, float]:
    """
    Compute fraction of k-nearest neighbors preserved after correction.

    Range: [0, 1], where 1 = perfect preservation.
    Higher = better biological structure preservation.

    Parameters
    ----------
    X_before : npt.NDArray[Any]
        Original data.
    X_after : npt.NDArray[Any]
        Corrected data.
    k_values : list of int
        Values of k for k-NN.
    n_jobs : int
        Number of parallel jobs.
    nn_algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used for nearest neighbor computation. Passed to
        ``sklearn.neighbors.NearestNeighbors``.

    Returns
    -------
    dict
        Mapping from k to the mean fraction of each sample's k nearest
        neighbours that are the same before and after correction. Any k
        larger than ``n_samples - 1`` maps to ``0.0``. An empty ``k_values``
        yields an empty dict.
    """
    if len(k_values) == 0:
        # Nothing requested; avoids calling max() on an empty sequence.
        return {}

    # One neighbour search at the largest usable k serves every smaller k.
    max_k = min(max(k_values), X_before.shape[0] - 1)

    def _neighbor_indices(data: npt.NDArray[Any]) -> npt.NDArray[Any]:
        """Return each sample's neighbour indices (self included) in ``data``."""
        nn = NearestNeighbors(n_neighbors=max_k + 1, algorithm=nn_algorithm, n_jobs=n_jobs)
        nn.fit(data)
        return nn.kneighbors(data)[1]

    indices_before = _neighbor_indices(X_before)
    indices_after = _neighbor_indices(X_after)

    results: dict[int, float] = {}
    for k in k_values:
        if k > max_k:
            # Not enough samples to form a neighbourhood of this size.
            results[k] = 0.0
            continue
        # Column 0 is taken to be the query point itself and is skipped.
        overlaps = [
            len(set(row_before[1 : k + 1]) & set(row_after[1 : k + 1])) / k
            for row_before, row_after in zip(indices_before, indices_after)
        ]
        results[k] = float(np.mean(overlaps))
    return results
def _pairwise_distance_correlation(
    X_before: npt.NDArray[Any],
    X_after: npt.NDArray[Any],
    subsample: int = 1000,
    random_state: int = 42,
) -> float:
    """
    Spearman rank correlation between pairwise distances before and after.

    Range: [-1, 1], where 1 means the ordering of all pairwise Euclidean
    distances is unchanged by the correction. Higher = better preservation
    of relative relationships.

    Parameters
    ----------
    X_before : npt.NDArray[Any]
        Original data.
    X_after : npt.NDArray[Any]
        Corrected data, row-aligned with ``X_before``.
    subsample : int
        Maximum number of samples to use; larger inputs are randomly
        subsampled (same rows from both matrices).
    random_state : int
        Random seed for subsampling.

    Returns
    -------
    float
        Spearman correlation coefficient. ``1.0`` is returned when there are
        no sample pairs or when the correlation is NaN.
    """
    total_rows = X_before.shape[0]
    if total_rows > subsample:
        generator = np.random.default_rng(random_state)
        chosen = generator.choice(total_rows, subsample, replace=False)
        X_before, X_after = X_before[chosen], X_after[chosen]

    pair_dist_before = pdist(X_before, metric="euclidean")
    pair_dist_after = pdist(X_after, metric="euclidean")
    if pair_dist_before.size == 0:
        # Fewer than two samples: there is no pair to compare.
        return 1.0

    rho, _p_value = spearmanr(pair_dist_before, pair_dist_after)
    return 1.0 if np.isnan(rho) else float(rho)
def _mean_centroid_distance(X: npt.NDArray[Any], batch_labels: npt.NDArray[Any]) -> float:
    """
    Average Euclidean distance between every pair of batch centroids.

    Lower after correction = better batch alignment.

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.

    Returns
    -------
    float
        Mean pairwise centroid distance; ``0.0`` when fewer than two
        batches are present.
    """
    batch_ids = np.unique(batch_labels)
    if len(batch_ids) < 2:
        return 0.0

    # One centroid (feature-wise mean) per batch, stacked row by row.
    centroid_matrix = np.array(
        [np.mean(X[batch_labels == batch_id], axis=0) for batch_id in batch_ids]
    )
    pair_distances = pdist(centroid_matrix, metric="euclidean")
    return float(np.mean(pair_distances))
def _levene_median_statistic(X: npt.NDArray[Any], batch_labels: npt.NDArray[Any]) -> float:
    """
    Median over features of the median-centred Levene statistic.

    Each feature is tested separately for equal variance across batches;
    the median of the resulting statistics is reported.
    Lower statistic = more homogeneous variances across batches.

    Parameters
    ----------
    X : npt.NDArray[Any]
        Data matrix.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.

    Returns
    -------
    float
        Median Levene statistic, or ``0.0`` when fewer than two batches are
        present or no feature produced a usable statistic.
    """
    batch_ids = np.unique(batch_labels)
    if len(batch_ids) < 2:
        return 0.0

    # Row selectors depend only on the labels, so build them once.
    batch_masks = [batch_labels == batch_id for batch_id in batch_ids]

    per_feature_stats = []
    for col in range(X.shape[1]):
        column = X[:, col]
        samples = [column[mask] for mask in batch_masks]
        samples = [group for group in samples if len(group) > 0]
        if len(samples) < 2:
            continue
        try:
            statistic, _p_value = levene(*samples, center="median")
            if not np.isnan(statistic):
                per_feature_stats.append(statistic)
        except Exception:
            # Best-effort: a feature that cannot be tested is skipped.
            continue

    if len(per_feature_stats) == 0:
        return 0.0
    return float(np.median(per_feature_stats))
def _compute_batch_effect_metrics(
    X_before: npt.NDArray[Any],
    X_after: npt.NDArray[Any],
    batch_labels: npt.NDArray[Any],
    *,
    kbet_k0: int,
    lisi_perplexity: int,
    nn_algorithm: str,
) -> dict[str, Any]:
    """
    Evaluate every batch-mixing metric on the data before and after correction.

    Parameters
    ----------
    X_before : npt.NDArray[Any]
        Data before correction.
    X_after : npt.NDArray[Any]
        Data after correction.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.
    kbet_k0 : int
        Neighborhood size forwarded to the kBET helper.
    lisi_perplexity : int
        Perplexity forwarded to the LISI helper.
    nn_algorithm : str
        Nearest-neighbour algorithm forwarded to the kBET and LISI helpers.

    Returns
    -------
    dict
        Keys ``silhouette``, ``davies_bouldin``, ``kbet``, ``lisi`` and
        ``variance_ratio``, each mapping to a ``{"before", "after"}`` dict.
        The ``lisi`` entry additionally carries ``max_value`` (the number of
        distinct batches). Only the kBET acceptance rate is kept; its mean
        statistic is discarded.
    """
    n_batches = len(np.unique(batch_labels))
    datasets = {"before": X_before, "after": X_after}

    def _before_after(metric: Any) -> dict[str, Any]:
        """Apply ``metric`` to the uncorrected, then the corrected data."""
        return {tag: metric(data) for tag, data in datasets.items()}

    silhouette = _before_after(lambda data: _silhouette_batch(data, batch_labels))
    davies_bouldin = _before_after(lambda data: _davies_bouldin_batch(data, batch_labels))
    kbet = _before_after(
        lambda data: _kbet_score(data, batch_labels, kbet_k0, nn_algorithm=nn_algorithm)[0]
    )
    lisi = _before_after(
        lambda data: _lisi_score(data, batch_labels, lisi_perplexity, nn_algorithm=nn_algorithm)
    )
    variance_ratio = _before_after(lambda data: _variance_ratio(data, batch_labels))

    lisi["max_value"] = n_batches

    return {
        "silhouette": silhouette,
        "davies_bouldin": davies_bouldin,
        "kbet": kbet,
        "lisi": lisi,
        "variance_ratio": variance_ratio,
    }
def _compute_preservation_metrics(
    X_before: npt.NDArray[Any],
    X_after: npt.NDArray[Any],
    k_neighbors: list[int],
    n_jobs: int,
    *,
    nn_algorithm: str,
) -> dict[str, Any]:
    """
    Gather the structure-preservation metrics for a before/after pair.

    Parameters
    ----------
    X_before : npt.NDArray[Any]
        Data before correction.
    X_after : npt.NDArray[Any]
        Data after correction.
    k_neighbors : list of int
        Values of k forwarded to the k-NN preservation helper.
    n_jobs : int
        Number of parallel jobs forwarded to the k-NN preservation helper.
    nn_algorithm : str
        Nearest-neighbour algorithm forwarded to the k-NN preservation helper.

    Returns
    -------
    dict
        ``knn``: mapping from k to preserved-neighbour fraction.
        ``distance_correlation``: pairwise-distance correlation computed
        with that helper's default subsampling settings.
    """
    metrics: dict[str, Any] = {}
    metrics["knn"] = _knn_preservation(
        X_before,
        X_after,
        k_neighbors,
        n_jobs,
        nn_algorithm=nn_algorithm,
    )
    metrics["distance_correlation"] = _pairwise_distance_correlation(X_before, X_after)
    return metrics
def _compute_alignment_metrics(
    X_before_orig: npt.NDArray[Any],
    X_after_orig: npt.NDArray[Any],
    X_before_pca: npt.NDArray[Any],
    X_after_pca: npt.NDArray[Any],
    batch_labels: npt.NDArray[Any],
) -> dict[str, Any]:
    """
    Gather the batch-alignment metrics for a before/after pair.

    Parameters
    ----------
    X_before_orig : npt.NDArray[Any]
        Uncorrected data in the original feature space.
    X_after_orig : npt.NDArray[Any]
        Corrected data in the original feature space.
    X_before_pca : npt.NDArray[Any]
        Uncorrected data in the (possibly PCA-reduced) metric space.
    X_after_pca : npt.NDArray[Any]
        Corrected data in the (possibly PCA-reduced) metric space.
    batch_labels : npt.NDArray[Any]
        Batch label of each sample.

    Returns
    -------
    dict
        ``centroid_distance`` and ``levene_statistic``, each mapping to a
        ``{"before", "after"}`` dict.
    """
    # Centroid distances are measured in the metric (possibly reduced) space.
    centroid_distance = {
        "before": _mean_centroid_distance(X_before_pca, batch_labels),
        "after": _mean_centroid_distance(X_after_pca, batch_labels),
    }
    # The Levene test works per feature, so it needs the original features
    # rather than principal components.
    levene_statistic = {
        "before": _levene_median_statistic(X_before_orig, batch_labels),
        "after": _levene_median_statistic(X_after_orig, batch_labels),
    }
    return {
        "centroid_distance": centroid_distance,
        "levene_statistic": levene_statistic,
    }
# [docs] -- Sphinx "viewcode" link marker left over from extraction; not part of the module.
def compute_batch_metrics(
    combat: ComBat,
    X: ArrayLike,
    batch: ArrayLike | None = None,
    *,
    pca_components: int | None = None,
    k_neighbors: list[int] | None = None,
    kbet_k0: int | None = None,
    lisi_perplexity: int = 30,
    n_jobs: int = 1,
    nn_algorithm: str = "auto",
) -> dict[str, Any]:
    """
    Compute batch effect metrics before and after ComBat correction.

    Parameters
    ----------
    combat : ComBat
        A fitted ``ComBat`` instance.
    X : array-like of shape (n_samples, n_features)
        Input data to evaluate.
    batch : array-like of shape (n_samples,), optional
        Batch labels. If None, uses the batch stored at construction.
        A pandas Series or single-column DataFrame is aligned to the index
        of ``X``; any other array-like is taken in row order.
    pca_components : int, optional
        Number of PCA components for dimensionality reduction before
        computing metrics. If None (default), metrics are computed in
        the original feature space. Must be less than min(n_samples, n_features).
    k_neighbors : list of int, default=[5, 10, 50]
        Values of k for k-NN preservation metric.
    kbet_k0 : int, optional
        Neighborhood size for kBET. Default is 10% of samples, but never
        fewer than 10.
    lisi_perplexity : int, default=30
        Perplexity for LISI computation.
    n_jobs : int, default=1
        Number of parallel jobs for neighbor computations.
    nn_algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto'
        Algorithm used for nearest neighbor computation. Passed to
        ``sklearn.neighbors.NearestNeighbors``.

    Returns
    -------
    dict
        Dictionary with three main keys:

        - ``batch_effect``: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio
          (each with 'before' and 'after' values)
        - ``preservation``: k-NN preservation fractions, distance correlation
        - ``alignment``: Centroid distance, Levene statistic (each with
          'before' and 'after' values)

    Raises
    ------
    ValueError
        If ``nn_algorithm`` is not recognised, the model is not fitted,
        the batch labels are not one-dimensional, or ``pca_components`` is
        invalid.
    """
    _valid_nn = {"auto", "ball_tree", "kd_tree", "brute"}
    if nn_algorithm not in _valid_nn:
        raise ValueError(f"nn_algorithm must be one of {_valid_nn}, got '{nn_algorithm}'")

    if not hasattr(combat, "_model") or not hasattr(combat._model, "_gamma_star"):
        raise ValueError(
            "This ComBat instance is not fitted yet. Call 'fit' before 'compute_batch_metrics'."
        )

    if not isinstance(X, pd.DataFrame):
        X = pd.DataFrame(X)
    idx = X.index

    if batch is None:
        batch_vec = _subset(combat.batch, idx)
    elif isinstance(batch, (pd.Series, pd.DataFrame)):
        # Labelled input: align to the rows of X by index.
        batch_vec = batch.loc[idx]
    else:
        # Plain array-like: assume the same row order as X.
        batch_vec = pd.Series(batch, index=idx)

    batch_labels = np.array(batch_vec)
    if batch_labels.ndim == 2 and batch_labels.shape[1] == 1:
        # A single-column DataFrame yields an (n, 1) array; the per-sample
        # metrics index labels as a flat vector.
        batch_labels = batch_labels[:, 0]
    if batch_labels.ndim != 1:
        raise ValueError(
            f"batch must be one-dimensional, got an array with shape {batch_labels.shape}."
        )

    X_before = X.values
    X_after = combat.transform(X).values

    n_samples, n_features = X_before.shape
    if kbet_k0 is None:
        kbet_k0 = max(10, int(0.10 * n_samples))
    if k_neighbors is None:
        k_neighbors = [5, 10, 50]

    # Validate and apply PCA if requested
    if pca_components is not None:
        max_components = min(n_samples, n_features)
        if pca_components >= max_components:
            raise ValueError(
                f"pca_components={pca_components} must be less than "
                f"min(n_samples, n_features)={max_components}."
            )
        X_before_pca, X_after_pca, _ = _compute_pca_embedding(X_before, X_after, pca_components)
    else:
        # No reduction: the "PCA" views are simply the original matrices.
        X_before_pca = X_before
        X_after_pca = X_after

    batch_effect = _compute_batch_effect_metrics(
        X_before_pca,
        X_after_pca,
        batch_labels,
        kbet_k0=kbet_k0,
        lisi_perplexity=lisi_perplexity,
        nn_algorithm=nn_algorithm,
    )
    preservation = _compute_preservation_metrics(
        X_before_pca,
        X_after_pca,
        k_neighbors,
        n_jobs,
        nn_algorithm=nn_algorithm,
    )
    # Alignment needs both spaces: centroids use the metric space, while the
    # Levene test uses the original features.
    alignment = _compute_alignment_metrics(
        X_before,
        X_after,
        X_before_pca,
        X_after_pca,
        batch_labels,
    )

    return {
        "batch_effect": batch_effect,
        "preservation": preservation,
        "alignment": alignment,
    }