Source code for capellini.utils.transforms

"""Numerical transformations: CLR, row normalization, shrinkage correlation."""

from __future__ import annotations

import numpy as np
import pandas as pd


def clr(X, pseudocount: float = 1e-6, eps: float = 1e-12):
    """Apply the centered log-ratio (CLR) transform to each row.

    Negative entries are floored at zero, ``pseudocount`` and then ``eps``
    are added, the natural log is taken, and every row is centered by
    subtracting its own mean log value.

    Args:
        X: Array-like of shape (n_samples, n_features); expected nonnegative.
        pseudocount: Offset added to every entry so zeros have a finite log.
        eps: Extra stability constant added inside the log.

    Returns:
        Numpy float array of the same shape as ``X`` holding the CLR values.
    """
    arr = np.asarray(X, dtype=float)
    # Floor negatives at zero first; the offsets are applied afterwards.
    nonneg = np.clip(arr, 0.0, None)
    logs = np.log((nonneg + pseudocount) + eps)
    # Per-row mean of the logs, kept 2-D so it broadcasts across columns.
    centers = logs.mean(axis=1, keepdims=True)
    return logs - centers
def row_normalize(W, eps: float = 1e-12):
    """Scale each row of a matrix by its row sum.

    Every row is divided by ``row_sum + eps``, so rows with a nonzero sum
    end up summing to (approximately) 1 and all-zero rows stay zero.

    Args:
        W: Matrix-like, two-dimensional.
        eps: Small constant added to each row sum to avoid division by zero.

    Returns:
        Row-normalized numpy float array with the same shape as ``W``.
    """
    mat = np.asarray(W, dtype=float)
    # keepdims retains the column axis so the division broadcasts per row.
    row_totals = mat.sum(axis=1, keepdims=True)
    denominators = row_totals + eps
    return mat / denominators
def schaefer_strimmer_corr(X: pd.DataFrame) -> tuple[pd.DataFrame, dict]:
    """Schäfer-Strimmer shrinkage correlation estimator.

    The empirical correlation matrix is shrunk toward the identity: every
    off-diagonal entry is multiplied by ``1 - lambda_corr`` and the diagonal
    is set to 1. Correlations that are undefined (for example those that
    involve a constant column) are treated as 0. A variance shrinkage
    intensity ``lambda_var`` is also computed, but it is only reported in
    the diagnostics and is not applied to the returned matrix.

    Args:
        X: Samples x features DataFrame (at least 3 samples required).

    Returns:
        Tuple ``(shrinkage_correlation_DataFrame, diagnostics_dict)``. The
        DataFrame is indexed and labelled by ``X.columns``; the diagnostics
        have keys ``n_samples``, ``n_features``, ``lambda_var``,
        ``lambda_corr``.

    Raises:
        ValueError: If ``X`` has fewer than 3 rows.
    """
    values = X.to_numpy(dtype=float)
    n, p = values.shape
    if n < 3:
        raise ValueError(f"Need at least 3 samples for shrinkage correlation, got {n}")

    # variance shrinkage parameter (computed for diagnostics)
    w = (values - values.mean(axis=0)) ** 2  # per-sample squared deviations
    w_bar = np.mean(w, axis=0)
    var_unb = (n / (n - 1)) * w_bar  # unbiased per-feature variance
    # Spread of the per-sample squared deviations around their mean, scaled
    # by n / (n - 1)^3; used as the numerator of the shrinkage intensity.
    var_s = (n / (n - 1) ** 3) * np.sum((w - w_bar) ** 2, axis=0)
    med_var = np.median(var_unb)  # shrinkage target: the median variance
    denom_var = np.sum((var_unb - med_var) ** 2)
    # Full shrinkage when the variances are (numerically) all equal to the
    # target; otherwise the ratio capped at 1.
    lambda_var = 1.0 if denom_var <= 1e-15 else min(1.0, float(np.sum(var_s) / denom_var))

    # off-diagonal shrinkage parameter
    sd = np.std(values, axis=0, ddof=1)
    # Standardize columns; (near-)constant columns are left as zeros rather
    # than being divided by a vanishing standard deviation.
    X_st = np.divide(values, sd, out=np.zeros_like(values), where=sd > 1e-12)
    X_c_st = X_st - X_st.mean(axis=0)
    w_st = X_c_st.T @ X_c_st  # sum over samples of x_ki * x_kj
    w_st_sq = (X_c_st ** 2).T @ (X_c_st ** 2)  # sum over samples of (x_ki * x_kj)^2
    w_bar_st = w_st / n
    # Expanded form of sum_k (x_ki * x_kj - w_bar_ij)^2, scaled by n / (n - 1)^3.
    var_s_st = (n / (n - 1) ** 3) * (w_st_sq - 2 * w_bar_st * w_st + n * w_bar_st ** 2)
    corr_unb_st = (n / (n - 1)) * w_bar_st
    # Only off-diagonal entries take part: subtract the diagonal contributions.
    denom_corr = np.sum(corr_unb_st ** 2) - np.sum(np.diag(corr_unb_st) ** 2)
    numer_corr = np.sum(var_s_st) - np.sum(np.diag(var_s_st))
    lambda_corr = 1.0 if denom_corr <= 1e-15 else min(1.0, float(numer_corr / denom_corr))

    # np.corrcoef divides by zero for constant columns; the resulting NaNs
    # are mapped to 0 just below, so the floating-point warnings are noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        corr_raw = np.corrcoef(values.T)
    # For a single feature np.corrcoef returns a 0-d value; promote it to a
    # 1x1 matrix so the diagonal fill and the DataFrame construction work.
    corr_X = np.nan_to_num(np.atleast_2d(corr_raw), nan=0.0, posinf=0.0, neginf=0.0)
    corr_shrink = (1.0 - lambda_corr) * corr_X
    np.fill_diagonal(corr_shrink, 1.0)

    out = pd.DataFrame(corr_shrink, index=X.columns, columns=X.columns)
    info = {
        "n_samples": int(n),
        "n_features": int(p),
        "lambda_var": float(lambda_var),
        "lambda_corr": float(lambda_corr),
    }
    return out, info