API Reference
Complete API documentation for combatlearn.
ComBat
The main scikit-learn compatible transformer for batch effect correction.
- class ComBat(batch, *, discrete_covariates=None, continuous_covariates=None, method='johnson', parametric=True, mean_only=False, reference_batch=None, eps=1e-08, covbat_cov_thresh=0.9)
Bases: BaseEstimator, TransformerMixin
Pipeline-friendly wrapper around ComBatModel.
Stores the batch labels (and optional covariates) passed at construction and reuses them in the separate fit and transform calls.
- Parameters:
batch (array-like of shape (n_samples,)) – Batch labels for each sample.
discrete_covariates (array-like, optional) – Categorical covariates to protect (Fortin/Chen only).
continuous_covariates (array-like, optional) – Continuous covariates to protect (Fortin/Chen only).
method ({'johnson', 'fortin', 'chen'}, default='johnson') – ComBat variant to use.
parametric (bool, default=True) – Use parametric empirical Bayes.
mean_only (bool, default=False) – Adjust only the mean (ignore variance).
reference_batch (str, optional) – Batch level to leave unchanged.
eps (float, default=1e-8) – Numerical jitter for stability.
covbat_cov_thresh (float or int, default=0.9) – CovBat variance threshold for PCs.
- __init__(batch, *, discrete_covariates=None, continuous_covariates=None, method='johnson', parametric=True, mean_only=False, reference_batch=None, eps=1e-08, covbat_cov_thresh=0.9)
- fit(X, y=None)
Fit the ComBat model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input data to fit.
y (None) – Ignored. Present for API compatibility.
- Returns:
self – Fitted estimator.
- Return type:
ComBat
- transform(X)
Transform the data using fitted ComBat parameters.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Input data to transform.
- Returns:
X_transformed – Batch-corrected data.
- Return type:
pd.DataFrame
- get_feature_names_out(input_features=None)
Get output feature names for transform.
- Parameters:
input_features (array-like of str or None, default=None) – Ignored. Present for API compatibility.
- Returns:
feature_names_out – Feature names.
- Return type:
ndarray of str objects
- Raises:
sklearn.exceptions.NotFittedError – If the estimator is not fitted.
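The design described for this class, where batch labels are supplied once at construction and then looked up again in separate fit and transform calls, can be sketched with a toy stand-in. The class below is hypothetical and is not ComBat: it only removes per-batch feature means (no empirical Bayes, no covariates), and it assumes that samples carry identifiers so that labels can be matched to whatever rows are passed in. How combatlearn itself matches labels to rows is not stated here.

```python
class BatchMeanCenterer:
    """Toy stand-in (NOT ComBat): removes per-batch feature means."""

    def __init__(self, batch):
        # Batch labels for ALL samples, keyed by sample id (an assumption).
        self.batch = dict(batch)

    def fit(self, X, y=None):
        # X maps sample id -> list of feature values (training rows only).
        n_features = len(next(iter(X.values())))
        sums, counts = {}, {}
        for sample_id, row in X.items():
            label = self.batch[sample_id]
            acc = sums.setdefault(label, [0.0] * n_features)
            for j, value in enumerate(row):
                acc[j] += value
            counts[label] = counts.get(label, 0) + 1
        self.batch_means_ = {
            label: [s / counts[label] for s in acc] for label, acc in sums.items()
        }
        n_samples = len(X)
        self.grand_mean_ = [
            sum(row[j] for row in X.values()) / n_samples for j in range(n_features)
        ]
        return self

    def transform(self, X):
        # Look up each row's batch by sample id, then shift it to the grand mean.
        out = {}
        for sample_id, row in X.items():
            means = self.batch_means_[self.batch[sample_id]]
            out[sample_id] = [
                v - m + g for v, m, g in zip(row, means, self.grand_mean_)
            ]
        return out


# Labels cover every sample; fit and transform each see a different subset.
batch = {"s1": "A", "s2": "A", "s3": "B", "s4": "B", "s5": "A", "s6": "B"}
train = {"s1": [1.0], "s2": [3.0], "s3": [11.0], "s4": [13.0]}
test = {"s5": [2.0], "s6": [12.0]}

centerer = BatchMeanCenterer(batch).fit(train)
corrected = centerer.transform(test)
```

Because the labels live on the estimator rather than in X, the same object can be fitted on training rows and applied to held-out rows, which is what makes this pattern convenient inside cross-validation.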
Inspection
Standalone functions for inspecting fitted ComBat models.
- feature_batch_diagnostics(combat, mode='magnitude', weighted=True)
Compute per-feature batch effect magnitude.
Returns a DataFrame with columns location, scale, and combined. Location is the (weighted) RMS of gamma across batches (standardized mean shifts). Scale is the (weighted) RMS of log-delta across batches (log-fold variance change). Combined is the Euclidean norm sqrt(location**2 + scale**2). Using RMS provides L2-consistent aggregation; using log(delta) ensures symmetry, so that a doubling and a halving of the variance contribute equally.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
mode ({'magnitude', 'distribution'}, default='magnitude') –
'magnitude': Returns L2-consistent absolute batch effect magnitudes. Suitable for ranking, thresholding, and cross-dataset comparison.
'distribution': Returns column-wise normalized proportions (each column sums to 1, values in the range [0, 1]), representing the relative contribution of each feature to the total location, scale, or combined batch effect. Note: normalization is applied independently to each column, so the Euclidean relationship (combined**2 = location**2 + scale**2) no longer holds.
weighted (bool, default=True) – If True, compute a weighted RMS in which each batch is weighted by its sample size. This gives more influence to larger batches, producing a more statistically representative summary. If False, all batches contribute equally regardless of size.
- Returns:
DataFrame with index=feature names, columns=[‘location’, ‘scale’, ‘combined’], sorted by ‘combined’ descending.
- Return type:
pd.DataFrame
- Raises:
ValueError – If the model is not fitted or if mode is invalid.
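The aggregation described above can be sketched in plain Python. This is an illustration of the stated formulas, not the library's implementation: the gamma and delta values are made up, their batch-by-feature layout is an assumption, and the weighted RMS is assumed to take the form sqrt(sum(n_b * x_b**2) / sum(n_b)).

```python
import math


def weighted_rms(values, weights):
    """Root mean square of `values`, each squared term weighted by `weights`."""
    total = sum(weights)
    return math.sqrt(sum(w * v * v for v, w in zip(values, weights)) / total)


def magnitude_table(gamma, delta, batch_sizes, features, weighted=True):
    """Per-feature location, scale, and combined magnitudes.

    gamma, delta: one row per batch, one column per feature (assumed layout).
    """
    weights = batch_sizes if weighted else [1] * len(batch_sizes)
    rows = {}
    for j, name in enumerate(features):
        location = weighted_rms([row[j] for row in gamma], weights)
        scale = weighted_rms([math.log(row[j]) for row in delta], weights)
        rows[name] = {
            "location": location,
            "scale": scale,
            "combined": math.hypot(location, scale),
        }
    # Sort by 'combined', largest first, mirroring the documented ordering.
    return dict(sorted(rows.items(), key=lambda kv: kv[1]["combined"], reverse=True))


def distribution_table(rows):
    """Normalize each column independently so that it sums to 1."""
    out = {name: {} for name in rows}
    for col in ("location", "scale", "combined"):
        total = sum(r[col] for r in rows.values())
        for name, r in rows.items():
            out[name][col] = r[col] / total if total > 0 else 0.0
    return out


# Made-up parameters for two batches and two features.
gamma = [[0.5, -0.1], [-0.5, 0.3]]  # standardized mean shifts
delta = [[2.0, 1.0], [0.5, 1.0]]    # variance factors: doubled vs. halved
batch_sizes = [30, 10]
features = ["gene_a", "gene_b"]

mag = magnitude_table(gamma, delta, batch_sizes, features)
dist = distribution_table(mag)
```

For gene_a the two batches have variance factors 2.0 and 0.5; on the log scale these are equal and opposite, so both contribute the same amount to the scale term. After normalization in `distribution_table`, each column sums to 1 but the Euclidean relationship between the columns is lost, as noted above.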
- summary(combat)
Return a human-readable diagnostic report after fitting.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
- Returns:
Multi-line summary string.
- Return type:
str
- Raises:
ValueError – If the model is not fitted.
Metrics
Functions for computing batch effect metrics and diagnostics.
- compute_batch_metrics(combat, X, batch=None, *, pca_components=None, k_neighbors=None, kbet_k0=None, lisi_perplexity=30, n_jobs=1, nn_algorithm='auto')
Compute batch effect metrics before and after ComBat correction.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
X (array-like of shape (n_samples, n_features)) – Input data to evaluate.
batch (array-like of shape (n_samples,), optional) – Batch labels. If None, uses the batch stored at construction.
pca_components (int, optional) – Number of PCA components for dimensionality reduction before computing metrics. If None (default), metrics are computed in the original feature space. Must be less than min(n_samples, n_features).
k_neighbors (list of int, optional) – Values of k for the k-NN preservation metric. If None, [5, 10, 50] is used.
kbet_k0 (int, optional) – Neighborhood size for kBET. If None, 10% of the number of samples is used.
lisi_perplexity (int, default=30) – Perplexity for LISI computation.
n_jobs (int, default=1) – Number of parallel jobs for neighbor computations.
nn_algorithm ({'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto') – Algorithm used for nearest neighbor computation. Passed to sklearn.neighbors.NearestNeighbors.
- Returns:
Dictionary with three main keys:
batch_effect: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio (each with 'before' and 'after' values)
preservation: k-NN preservation fractions, distance correlation
alignment: Centroid distance, Levene statistic (each with 'before' and 'after' values)
- Return type:
dict
- Raises:
ValueError – If the model is not fitted or if pca_components is invalid.
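The k-NN preservation fraction listed under the preservation key is not defined on this page. A common reading, assumed here, is the average share of each sample's k nearest neighbors before correction that are still among its k nearest neighbors afterwards. The brute-force sketch below illustrates that reading with made-up points; it is not the library's code, which delegates neighbor search to sklearn.neighbors.NearestNeighbors.

```python
import math


def knn_indices(points, k):
    """For each point, the set of indices of its k nearest other points."""
    neighbors = []
    for i, p in enumerate(points):
        # Sort the other points by Euclidean distance (ties broken by index).
        dists = sorted(
            (math.dist(p, q), j) for j, q in enumerate(points) if j != i
        )
        neighbors.append({j for _, j in dists[:k]})
    return neighbors


def knn_preservation(before, after, k):
    """Mean fraction of each sample's k neighbors shared before and after."""
    nb_before = knn_indices(before, k)
    nb_after = knn_indices(after, k)
    overlaps = [len(a & b) / k for a, b in zip(nb_before, nb_after)]
    return sum(overlaps) / len(overlaps)


# Made-up 1-D data: a pure shift keeps every neighborhood intact,
# while swapping two samples' positions does not.
before = [(0,), (1,), (2,), (10,)]
shifted = [(5,), (6,), (7,), (15,)]
scrambled = [(0,), (10,), (2,), (1,)]

same = knn_preservation(before, shifted, k=2)
changed = knn_preservation(before, scrambled, k=1)
```

A value near 1 suggests that the correction left local structure largely untouched; a low value suggests that the correction rearranged neighborhoods, which may or may not be desirable depending on how strongly batch drove the original structure.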
Visualization
Functions for visualizing batch effects and ComBat corrections.
- plot_transformation(combat, X, *, reduction_method='pca', n_components=2, plot_type='static', figsize=(12, 5), alpha=0.7, point_size=50, cmap='Set1', title=None, show_legend=True, return_embeddings=False, **reduction_kwargs)
Visualize the ComBat transformation effect using dimensionality reduction.
It shows a before/after comparison of data transformed by ComBat using PCA, t-SNE, or UMAP to reduce dimensions for visualization.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
X (array-like of shape (n_samples, n_features)) – Input data to transform and visualize.
reduction_method ({'pca', 'tsne', 'umap'}, default='pca') – Dimensionality reduction method.
n_components ({2, 3}, default=2) – Number of components for dimensionality reduction.
plot_type ({'static', 'interactive'}, default='static') – Visualization type:
'static': matplotlib plots (can be saved as images)
'interactive': plotly plots (explorable, requires plotly)
figsize (tuple of int, default=(12, 5)) – Figure size in inches (width, height). Only used for static plots.
alpha (float, default=0.7) – Marker transparency. Only used for static plots.
point_size (int, default=50) – Marker size. Only used for static plots.
cmap (str, default='Set1') – Matplotlib colormap name for batch colors.
title (str or None, default=None) – Custom figure title. If None, a default title is generated.
show_legend (bool, default=True) – Whether to display the batch legend.
return_embeddings (bool, default=False) – If True, return embeddings along with the plot.
**reduction_kwargs (dict) – Additional keyword arguments passed to the reduction method (e.g., perplexity for t-SNE, n_neighbors for UMAP).
- Returns:
fig (matplotlib.figure.Figure or plotly.graph_objects.Figure) – The figure object containing the plots.
embeddings (dict, optional) – Returned only if return_embeddings=True. Dictionary with:
'original': embedding of the original data
'transformed': embedding of the ComBat-transformed data
- Return type:
- plot_feature_diagnostics(combat, top_n=20, kind='combined', mode='magnitude', weighted=True, figsize=(8, 10))
Plot top features affected by batch effects.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
top_n (int, default=20) – Number of top features to display.
kind ({'location', 'scale', 'combined'}, default='combined') –
'location': bar plot of location (mean shift) contribution only
'scale': bar plot of scale (variance) contribution only
'combined': grouped bar plot showing location and scale side-by-side for each feature (sorted by Euclidean magnitude). In magnitude mode, bars reflect the Euclidean decomposition (combined**2 = location**2 + scale**2). In distribution mode, bars reflect independently normalized contributions (each sums to 1 separately).
mode ({'magnitude', 'distribution'}, default='magnitude') –
'magnitude': y-axis shows absolute batch effect magnitude
'distribution': y-axis shows relative contribution (proportion) and includes an annotation showing the cumulative contribution of the top_n features (e.g., "Top 20 features explain 75% of total batch effect")
weighted (bool, default=True) – If True, batch effects are weighted by batch sample size. Passed to feature_batch_diagnostics().
figsize (tuple, default=(8, 10)) – Figure size (width, height) in inches.
- Returns:
The figure object containing the plot.
- Return type:
matplotlib.figure.Figure
- Raises:
ValueError – If the model is not fitted, or if kind/mode is invalid.
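The cumulative-contribution annotation described for distribution mode amounts to summing the largest top_n proportions. The sketch below shows that arithmetic on a made-up column of proportions; the helper name `top_n_share` is hypothetical, and which column the library actually sums for the annotation is an assumption.

```python
def top_n_share(proportions, top_n):
    """Sum of the `top_n` largest proportions."""
    ranked = sorted(proportions, reverse=True)
    return sum(ranked[:top_n])


# Hypothetical 'combined' column in distribution mode (sums to 1).
combined = [0.05, 0.40, 0.10, 0.25, 0.20]
top_n = 3

share = top_n_share(combined, top_n)
annotation = f"Top {top_n} features explain {share:.0%} of total batch effect"
```

A share close to 1 for a small top_n indicates that the batch effect is concentrated in a few features, which is the situation in which a top-features bar plot is most informative.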
- plot_batch_effect_heatmap(combat, top_n=50, weighted=True, figsize=(12, 8))
Plot a heatmap of batch effect parameters across features and batches.
Displays the estimated batch-specific location shifts (gamma) and, unless mean_only=True, log-scale shifts (log delta) for the top_n most affected features.
- Parameters:
combat (ComBat) – A fitted ComBat instance.
top_n (int, default=50) – Number of top features (by combined batch effect) to display.
weighted (bool, default=True) – If True, feature ranking uses sample-size-weighted batch effects. Passed to feature_batch_diagnostics().
figsize (tuple of int, default=(12, 8)) – Figure size in inches.
- Returns:
Figure containing the heatmap(s).
- Return type:
matplotlib.figure.Figure
- Raises:
ValueError – If the model is not fitted.
ImportError – If seaborn is not installed.