API Reference

Complete API documentation for combatlearn.

ComBat

The main scikit-learn compatible transformer for batch effect correction.

class ComBat(batch, *, discrete_covariates=None, continuous_covariates=None, method='johnson', parametric=True, mean_only=False, reference_batch=None, eps=1e-08, covbat_cov_thresh=0.9)

Bases: BaseEstimator, TransformerMixin

Pipeline-friendly wrapper around ComBatModel.

Stores the batch labels (and optional covariates) passed at construction and applies them consistently in the separate fit and transform steps.

Parameters:
  • batch (array-like of shape (n_samples,)) – Batch labels for each sample.

  • discrete_covariates (array-like, optional) – Categorical covariates to protect (used only when method is 'fortin' or 'chen').

  • continuous_covariates (array-like, optional) – Continuous covariates to protect (used only when method is 'fortin' or 'chen').

  • method ({'johnson', 'fortin', 'chen'}, default='johnson') – ComBat variant to use.

  • parametric (bool, default=True) – Use parametric empirical Bayes.

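  • mean_only (bool, default=False) – Adjust only the batch means (location); variances (scale) are left unadjusted.

  • reference_batch (str, optional) – Batch level used as the reference: its samples are left unchanged and the other batches are adjusted toward it.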

  • eps (float, default=1e-8) – Numerical jitter for stability.

  • covbat_cov_thresh (float or int, default=0.9) – CovBat (method='chen') variance threshold used to select the principal components.
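To make the roles of mean_only, reference_batch, and eps concrete, here is a minimal NumPy sketch of a plain per-batch location/scale adjustment. It is not combatlearn's implementation: it ignores covariates and skips the empirical Bayes step that parametric controls, and the function name location_scale_adjust is hypothetical.

```python
import numpy as np


def location_scale_adjust(X, batch, *, mean_only=False, reference_batch=None, eps=1e-8):
    """Hypothetical per-batch location/scale adjustment (no covariates, no empirical Bayes)."""
    X = np.asarray(X, dtype=float)
    batch = np.asarray(batch)
    levels = np.unique(batch)

    # Within-batch residuals give a pooled within-batch standard deviation per feature.
    resid = np.empty_like(X)
    for b in levels:
        rows = batch == b
        resid[rows] = X[rows] - X[rows].mean(axis=0)

    if reference_batch is None:
        # Align every batch to the grand mean and the pooled within-batch scale.
        target_mean = X.mean(axis=0)
        target_std = resid.std(axis=0)
    else:
        # Align every batch to the reference batch's own mean and scale.
        ref_rows = batch == reference_batch
        target_mean = X[ref_rows].mean(axis=0)
        target_std = X[ref_rows].std(axis=0)

    out = np.empty_like(X)
    for b in levels:
        rows = batch == b
        if mean_only:
            # Location only: remove the batch mean shift, keep the batch's own spread.
            out[rows] = resid[rows] + target_mean
        else:
            # eps guards against division by zero for constant features.
            out[rows] = resid[rows] / (X[rows].std(axis=0) + eps) * target_std + target_mean

    if reference_batch is not None:
        # The reference batch is returned exactly as it was passed in.
        out[batch == reference_batch] = X[batch == reference_batch]
    return out
```

With mean_only=True each batch keeps its own spread and only its mean moves; with a reference batch, the other batches take on that batch's mean and scale.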

__init__(batch, *, discrete_covariates=None, continuous_covariates=None, method='johnson', parametric=True, mean_only=False, reference_batch=None, eps=1e-08, covbat_cov_thresh=0.9)

fit(X, y=None)

Fit the ComBat model.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Input data to fit.

  • y (None) – Ignored. Present for API compatibility.

Returns:

self – Fitted estimator.

Return type:

ComBat

transform(X)

Transform the data using fitted ComBat parameters.

Parameters:

X (array-like of shape (n_samples, n_features)) – Input data to transform.

Returns:

X_transformed – Batch-corrected data.

Return type:

pd.DataFrame
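One way a transformer can store batch labels at construction and still support separate fit and transform calls on different rows is to look the labels up by sample index. The sketch below assumes batch is a pandas Series indexed like X; it only removes per-batch mean shifts, and the class name IndexAlignedMeanCorrector is hypothetical, not part of combatlearn.

```python
import pandas as pd


class IndexAlignedMeanCorrector:
    """Hypothetical transformer: batch labels stored at construction, matched by index."""

    def __init__(self, batch: pd.Series):
        self.batch = batch  # indexed like the full dataset

    def fit(self, X: pd.DataFrame, y=None):
        labels = self.batch.loc[X.index]  # labels for the training rows only
        self.grand_mean_ = X.mean(axis=0)
        self.batch_means_ = X.groupby(labels).mean()
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        labels = self.batch.loc[X.index]  # labels for whichever rows are passed in
        # Per-row shift = (mean of that row's batch at fit time) - grand mean.
        shifts = self.batch_means_.loc[labels.to_numpy()] - self.grand_mean_
        return X - shifts.to_numpy()
```

In this sketch, transform raises a KeyError for rows whose batch was not seen during fit.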

get_feature_names_out(input_features=None)

Get output feature names for transform.

Parameters:

input_features (array-like of str or None, default=None) – Ignored. Present for API compatibility.

Returns:

feature_names_out – Feature names.

Return type:

ndarray of str objects

Raises:

sklearn.exceptions.NotFittedError – If the estimator is not fitted.

Inspection

Functions for inspecting fitted ComBat models.

feature_batch_diagnostics(combat, mode='magnitude', weighted=True)

Compute per-feature batch effect magnitude.

Returns a DataFrame with columns location, scale, and combined. Location is the (weighted) RMS of gamma across batches, i.e. of the standardized mean shifts. Scale is the (weighted) RMS of log(delta) across batches, i.e. of the log-fold variance changes. Combined is the Euclidean norm sqrt(location**2 + scale**2). RMS aggregation keeps the three measures L2-consistent, and taking log(delta) makes scale effects symmetric: a twofold increase (delta = 2) and a twofold decrease (delta = 0.5) contribute equally, since |log 2| = |log 0.5|.

Parameters:
  • combat (ComBat) – A fitted ComBat instance.

  • mode ({'magnitude', 'distribution'}, default='magnitude') –

    • 'magnitude': Returns L2-consistent absolute batch effect magnitudes. Suitable for ranking, thresholding, and cross-dataset comparison.

    • 'distribution': Returns column-wise normalized proportions (each column sums to 1, values in range [0, 1]), representing the relative contribution of each feature to the total location, scale, or combined batch effect. Note: normalization is applied independently to each column, so the Euclidean relationship (combined**2 = location**2 + scale**2) no longer holds.

  • weighted (bool, default=True) – If True, compute a weighted RMS where each batch is weighted by its sample size. This gives more influence to larger batches, producing a more statistically representative summary. If False, all batches contribute equally regardless of size.

Returns:

DataFrame with index=feature names, columns=['location', 'scale', 'combined'], sorted by 'combined' in descending order.

Return type:

pd.DataFrame

Raises:

ValueError – If the model is not fitted or if mode is invalid.
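Given per-batch estimates of gamma and delta, the documented summary can be reproduced roughly as below. This is a sketch under stated assumptions: gamma and delta are arrays of shape (n_batches, n_features), the batch weights are normalized to sum to 1, and log is the natural logarithm. The function name batch_effect_table is hypothetical; on a fitted model, call feature_batch_diagnostics instead.

```python
import numpy as np
import pandas as pd


def batch_effect_table(gamma, delta, batch_sizes, feature_names, *, mode="magnitude", weighted=True):
    """Hypothetical re-implementation of the location/scale/combined summary."""
    gamma = np.asarray(gamma, dtype=float)             # (n_batches, n_features)
    log_delta = np.log(np.asarray(delta, dtype=float))
    if weighted:
        w = np.asarray(batch_sizes, dtype=float)       # weight each batch by its size
    else:
        w = np.ones(gamma.shape[0])                    # all batches count equally
    w = w / w.sum()

    location = np.sqrt(w @ gamma**2)                   # weighted RMS of gamma
    scale = np.sqrt(w @ log_delta**2)                  # weighted RMS of log(delta)
    combined = np.sqrt(location**2 + scale**2)         # Euclidean norm
    df = pd.DataFrame(
        {"location": location, "scale": scale, "combined": combined},
        index=feature_names,
    )
    if mode == "distribution":
        df = df / df.sum(axis=0)                       # each column normalized on its own
    elif mode != "magnitude":
        raise ValueError("mode must be 'magnitude' or 'distribution'")
    return df.sort_values("combined", ascending=False)
```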

summary(combat)

Return a human-readable diagnostic report after fitting.

Parameters:

combat (ComBat) – A fitted ComBat instance.

Returns:

Multi-line summary string.

Return type:

str

Raises:

ValueError – If the model is not fitted.

Metrics

Functions for computing batch effect metrics.

compute_batch_metrics(combat, X, batch=None, *, pca_components=None, k_neighbors=None, kbet_k0=None, lisi_perplexity=30, n_jobs=1, nn_algorithm='auto')

Compute batch effect metrics before and after ComBat correction.

Parameters:
  • combat (ComBat) – A fitted ComBat instance.

  • X (array-like of shape (n_samples, n_features)) – Input data to evaluate.

  • batch (array-like of shape (n_samples,), optional) – Batch labels. If None, uses the batch stored at construction.

  • pca_components (int, optional) – Number of PCA components for dimensionality reduction before computing metrics. If None (default), metrics are computed in the original feature space. Must be less than min(n_samples, n_features).

  • k_neighbors (list of int, optional) – Values of k for the k-NN preservation metric. If None, [5, 10, 50] is used.

  • kbet_k0 (int, optional) – Neighborhood size for kBET. If None, 10% of the number of samples is used.

  • lisi_perplexity (int, default=30) – Perplexity for LISI computation.

  • n_jobs (int, default=1) – Number of parallel jobs for neighbor computations.

  • nn_algorithm ({'auto', 'ball_tree', 'kd_tree', 'brute'}, default='auto') – Algorithm used for nearest neighbor computation. Passed to sklearn.neighbors.NearestNeighbors.

Returns:

Dictionary with three main keys:

  • batch_effect: Silhouette, Davies-Bouldin, kBET, LISI, variance ratio (each with 'before' and 'after' values)

  • preservation: k-NN preservation fractions, distance correlation

  • alignment: Centroid distance, Levene statistic (each with 'before' and 'after' values)

Return type:

dict

Raises:

ValueError – If the model is not fitted or if pca_components is invalid.
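The k-NN preservation fractions compare neighborhoods before and after correction. The exact definition used by compute_batch_metrics is not given here; the sketch below assumes a common one (the average fraction of each sample's k nearest neighbors in the original space that remain among its k nearest neighbors after correction), computed by brute force with NumPy rather than sklearn.neighbors.NearestNeighbors. Both function names are hypothetical.

```python
import numpy as np


def knn_indices(X, k):
    """Indices of the k nearest neighbors of each row (self excluded), brute force."""
    X = np.asarray(X, dtype=float)
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T  # squared Euclidean distances
    np.fill_diagonal(d2, np.inf)                    # a sample is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def knn_preservation(X_before, X_after, k):
    """Mean fraction of original k-NN that are still k-NN after correction."""
    before = knn_indices(X_before, k)
    after = knn_indices(X_after, k)
    overlap = [len(set(b) & set(a)) / k for b, a in zip(before, after)]
    return float(np.mean(overlap))
```

A value of 1.0 means every neighborhood survived the correction unchanged; lower values indicate that local structure was altered.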

Visualization

Functions for visualizing batch effects and ComBat corrections.

plot_transformation(combat, X, *, reduction_method='pca', n_components=2, plot_type='static', figsize=(12, 5), alpha=0.7, point_size=50, cmap='Set1', title=None, show_legend=True, return_embeddings=False, **reduction_kwargs)

Visualize the ComBat transformation effect using dimensionality reduction.

Shows a before/after comparison of the data, embedded in 2 or 3 dimensions with PCA, t-SNE, or UMAP, so the effect of the ComBat transformation can be inspected visually.

Parameters:
  • combat (ComBat) – A fitted ComBat instance.

  • X (array-like of shape (n_samples, n_features)) – Input data to transform and visualize.

  • reduction_method ({'pca', 'tsne', 'umap'}, default='pca') – Dimensionality reduction method.

  • n_components ({2, 3}, default=2) – Number of components for dimensionality reduction.

  • plot_type ({'static', 'interactive'}, default='static') – Visualization type:

    • 'static': matplotlib plots (can be saved as images).

    • 'interactive': plotly plots (explorable; requires plotly).

  • figsize (tuple of int, default=(12, 5)) – Figure size in inches (width, height). Only used for static plots.

  • alpha (float, default=0.7) – Marker transparency. Only used for static plots.

  • point_size (int, default=50) – Marker size. Only used for static plots.

  • cmap (str, default='Set1') – Matplotlib colormap name for batch colors.

  • title (str or None, default=None) – Custom figure title. If None, a default title is generated.

  • show_legend (bool, default=True) – Whether to display the batch legend.

  • return_embeddings (bool, default=False) – If True, return embeddings along with the plot.

  • **reduction_kwargs (dict) – Additional keyword arguments passed to the reduction method (e.g., perplexity for t-SNE, n_neighbors for UMAP).

Returns:

  • fig (matplotlib.figure.Figure or plotly.graph_objects.Figure) – The figure object containing the plots.

  • embeddings (dict, optional) – Returned only if return_embeddings=True, in which case the function returns the tuple (fig, embeddings). Dictionary with:

    • 'original': embedding of the original data

    • 'transformed': embedding of the ComBat-transformed data

Return type:

Any | tuple[Any, dict[str, FloatArray]]
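For the PCA case, the embeddings dictionary returned with return_embeddings=True can be illustrated as below. This sketch assumes PCA is fitted separately to the original and the corrected data (combatlearn may instead use a shared basis); it uses only NumPy, and the helper names are hypothetical.

```python
import numpy as np


def pca_embed(X, n_components=2):
    """Project centered data onto its leading principal axes via SVD."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)
    # Rows of Vt are principal axes, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T


def before_after_embeddings(X, X_corrected, n_components=2):
    """Hypothetical helper mirroring the 'original'/'transformed' keys."""
    if n_components not in (2, 3):
        raise ValueError("n_components must be 2 or 3")
    return {
        "original": pca_embed(X, n_components),
        "transformed": pca_embed(X_corrected, n_components),
    }
```

Fitting PCA separately means the axes of the two panels are not directly comparable and their signs are arbitrary; projecting both datasets onto the original data's principal axes is the usual alternative.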

plot_feature_diagnostics(combat, top_n=20, kind='combined', mode='magnitude', weighted=True, figsize=(8, 10))

Plot top features affected by batch effects.

Parameters:
  • combat (ComBat) – A fitted ComBat instance.

  • top_n (int, default=20) – Number of top features to display.

  • kind ({'location', 'scale', 'combined'}, default='combined') –

    • 'location': bar plot of the location (mean shift) contribution only

    • 'scale': bar plot of the scale (variance) contribution only

    • 'combined': grouped bar plot showing location and scale side by side for each feature, sorted by Euclidean magnitude. In magnitude mode, the bars follow the Euclidean decomposition combined**2 = location**2 + scale**2. In distribution mode, location and scale are each normalized independently to sum to 1, so that relationship no longer holds.

  • mode ({'magnitude', 'distribution'}, default='magnitude') –

    • 'magnitude': y-axis shows absolute batch effect magnitude

    • 'distribution': y-axis shows relative contribution (proportion) and includes an annotation with the cumulative contribution of the top_n features (e.g., "Top 20 features explain 75% of total batch effect")

  • weighted (bool, default=True) – If True, batch effects are weighted by batch sample size. Passed to feature_batch_diagnostics().

  • figsize (tuple, default=(8, 10)) – Figure size (width, height) in inches.

Returns:

The figure object containing the plot.

Return type:

matplotlib.figure.Figure

Raises:

ValueError – If the model is not fitted, or if kind/mode is invalid.
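In distribution mode, the annotation reports the cumulative share of the top_n features. A minimal sketch of that arithmetic, assuming the share is computed from the 'combined' column; the helper name top_n_share is hypothetical.

```python
import numpy as np


def top_n_share(combined, top_n=20):
    """Hypothetical helper: cumulative proportion of the top_n largest values."""
    values = np.sort(np.asarray(combined, dtype=float))[::-1]  # descending order
    proportions = values / values.sum()                        # distribution-mode normalization
    share = float(proportions[:top_n].sum())
    n_shown = min(top_n, values.size)
    label = f"Top {n_shown} features explain {share:.0%} of total batch effect"
    return share, label
```

For example, combined magnitudes [1, 4, 2, 3] with top_n=2 give a share of 0.7, since the two largest values (4 and 3) account for 7 of the total 10.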

plot_batch_effect_heatmap(combat, top_n=50, weighted=True, figsize=(12, 8))

Plot a heatmap of batch effect parameters across features and batches.

Displays the estimated batch-specific location shifts (gamma) and, unless mean_only=True, log-scale shifts (log delta) for the top_n most affected features.

Parameters:
  • combat (ComBat) – A fitted ComBat instance.

  • top_n (int, default=50) – Number of top features (by combined batch effect) to display.

  • weighted (bool, default=True) – If True, feature ranking uses sample-size-weighted batch effects. Passed to feature_batch_diagnostics().

  • figsize (tuple of int, default=(12, 8)) – Figure size (width, height) in inches.

Returns:

Figure containing the heatmap(s).

Return type:

matplotlib.figure.Figure

Raises: