"""Network-level utilities: message passing, CRISPR smoothing, taxonomy kernels, abundance helpers."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from capellini.utils.transforms import (
closure,
clr,
double_clr_transform,
row_normalize,
)
from capellini.utils.taxonomy import (
sanitize_taxon_name,
sanitize_index,
clean_index_ids,
clean_df_ids,
parse_bool_series,
load_bacteria_taxonomy,
clean_bacteria_taxonomy,
apply_custom_renames,
rename_clostridium_sensu_stricto,
)
from capellini.utils.io import read_table, write_df
# ── Validation helpers ─────────────────────────────────────────────────────────
def _validate_inputs(V: np.ndarray, B: np.ndarray, W: np.ndarray, lam: float, n_steps: int) -> None:
    """Validate shapes and parameter ranges for message-passing functions.

    Args:
        V: Samples x viruses matrix.
        B: Samples x bacteria matrix.
        W: Viruses x bacteria CRISPR matrix.
        lam: Mixing weight; must be nonnegative (NaN is rejected).
        n_steps: Number of propagation steps; must be at least 1.

    Raises:
        ValueError: On non-2-D input, shape mismatch, or invalid parameter values.
    """
    # Check dimensionality first so a 1-D/3-D input gets a clear message
    # instead of a tuple-unpacking error from the shape assignments below.
    for label, arr in (("V", V), ("B", B), ("W_vh_smooth", W)):
        ndim = np.ndim(arr)
        if ndim != 2:
            raise ValueError(f"{label} must be a 2-D array, got {ndim} dimension(s).")

    n, q = np.shape(V)
    n_b, p = np.shape(B)
    if n != n_b:
        raise ValueError(f"V and B must have the same number of samples, got {n} and {n_b}.")
    if np.shape(W) != (q, p):
        raise ValueError(f"W_vh_smooth must have shape ({q}, {p}), got {np.shape(W)}.")
    # `not lam >= 0` is True for NaN as well as for negative values,
    # whereas `lam < 0` would silently accept NaN.
    if not lam >= 0:
        raise ValueError(f"lam must be nonnegative, got {lam}.")
    if n_steps < 1:
        raise ValueError(f"n_steps must be at least 1, got {n_steps}.")
def _build_common_inputs(
    V_df: pd.DataFrame,
    B_df: pd.DataFrame,
    W_vh_smooth_df: pd.DataFrame,
    pseudocount: float,
    eps: float,
) -> tuple:
    """Prepare the aligned inputs shared by the X-star builders.

    Restricts both abundance tables to their common samples, runs
    ``double_clr_transform`` on them, then keeps only the viruses and
    bacteria that also appear in the CRISPR matrix.

    Args:
        V_df: Samples x viruses abundance.
        B_df: Samples x bacteria abundance.
        W_vh_smooth_df: Viruses x bacteria smoothed CRISPR matrix.
        pseudocount: CLR pseudocount, forwarded to ``double_clr_transform``.
        eps: Numerical stability constant, forwarded to ``double_clr_transform``.

    Returns:
        Tuple (V0, B0, V_clr_df, B_clr_df, X_clr_df, W_sub). V0/B0 are the raw
        tables restricted to the shared samples; the CLR tables and W_sub are
        restricted to the features shared with the CRISPR matrix; X_clr_df is
        the column-wise concatenation of V_clr_df and B_clr_df.

    Raises:
        ValueError: If no overlapping samples, viruses, or bacteria are found.
    """
    shared_samples = V_df.index.intersection(B_df.index)
    if shared_samples.empty:
        raise ValueError("No overlapping sample IDs between V_df and B_df.")

    V0 = V_df.loc[shared_samples].copy()
    B0 = B_df.loc[shared_samples].copy()

    # The third value returned by the transform is not needed here.
    V_clr_df, B_clr_df, _ = double_clr_transform(V0, B0, pseudocount=pseudocount, eps=eps)

    vir_ids = V_clr_df.columns.intersection(W_vh_smooth_df.index)
    bac_ids = B_clr_df.columns.intersection(W_vh_smooth_df.columns)
    overlap_checks = (
        (vir_ids, "No overlapping viruses between V_clr_df.columns and W_vh_smooth_df.index."),
        (bac_ids, "No overlapping bacteria between B_clr_df.columns and W_vh_smooth_df.columns."),
    )
    for feature_ids, message in overlap_checks:
        if feature_ids.empty:
            raise ValueError(message)

    V_clr_df = V_clr_df.loc[:, vir_ids].copy()
    B_clr_df = B_clr_df.loc[:, bac_ids].copy()
    W_sub = W_vh_smooth_df.loc[vir_ids, bac_ids].copy()
    X_clr_df = pd.concat([V_clr_df, B_clr_df], axis=1)
    return V0, B0, V_clr_df, B_clr_df, X_clr_df, W_sub
# ── Message passing ────────────────────────────────────────────────────────────
[docs]
def build_xstar_from_smoothed_crispr(
    V_df: pd.DataFrame,
    B_df: pd.DataFrame,
    W_vh_smooth_df: pd.DataFrame,
    pseudocount: float = 1e-6,
    lam: float = 0.1,
    n_steps: int = 1,
    preserve_scale: bool = False,
    eps: float = 1e-12,
) -> dict[str, pd.DataFrame]:
    """Run the full non-residual X-star pipeline (convex message passing).

    Args:
        V_df: Samples x viruses raw abundance.
        B_df: Samples x bacteria raw abundance.
        W_vh_smooth_df: Viruses x bacteria smoothed CRISPR matrix.
        pseudocount: CLR pseudocount.
        lam: Mixing weight.
        n_steps: Number of propagation steps.
        preserve_scale: If True, restore original column standard deviations.
        eps: Numerical stability constant.

    Returns:
        Dict with keys: V_raw_aligned, B_raw_aligned, V_clr, B_clr, X_clr,
        V_star, B_star, X_star, W_smooth_aligned.
    """
    aligned = _build_common_inputs(
        V_df, B_df, W_vh_smooth_df, pseudocount=pseudocount, eps=eps
    )
    V0, B0, V_clr_df, B_clr_df, X_clr_df, W_sub = aligned

    # NOTE(review): transform_message_passing_smoothed_crispr_df is neither
    # defined nor imported in the visible part of this module — confirm it is
    # in scope, otherwise this call raises NameError.
    V_star_df, B_star_df, X_star_df = transform_message_passing_smoothed_crispr_df(
        V_clr_df, B_clr_df, W_sub, lam=lam, n_steps=n_steps, preserve_scale=preserve_scale
    )

    keys = (
        "V_raw_aligned", "B_raw_aligned",
        "V_clr", "B_clr", "X_clr",
        "V_star", "B_star", "X_star",
        "W_smooth_aligned",
    )
    values = (
        # Raw tables are cut down to the features that survived alignment.
        V0.loc[:, V_clr_df.columns], B0.loc[:, B_clr_df.columns],
        V_clr_df, B_clr_df, X_clr_df,
        V_star_df, B_star_df, X_star_df,
        W_sub,
    )
    return dict(zip(keys, values))
[docs]
def residual_message_passing(
    V: np.ndarray,
    B: np.ndarray,
    W_vh_smooth: np.ndarray,
    lam: float = 0.1,
    n_steps: int = 1,
    preserve_scale: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Residual additive cross-domain message passing using a smoothed CRISPR matrix.

    At each step the column-centered residuals of one domain are pushed
    through the row-normalized CRISPR matrix and added, scaled by ``lam``,
    to the other domain.

    Args:
        V: Samples x viruses CLR matrix.
        B: Samples x bacteria CLR matrix.
        W_vh_smooth: Viruses x bacteria smoothed CRISPR matrix.
        lam: Additive weight for cross-domain messages.
        n_steps: Number of propagation steps.
        preserve_scale: If True, restore original column standard deviations.

    Returns:
        Tuple (V_star, B_star).

    Raises:
        ValueError: Propagated from ``_validate_inputs`` on bad shapes or parameters.
    """
    V = np.asarray(V, dtype=float)
    B = np.asarray(B, dtype=float)
    W = np.asarray(W_vh_smooth, dtype=float)
    _validate_inputs(V, B, W, lam=lam, n_steps=n_steps)

    virus_to_host = row_normalize(W)
    host_to_virus = row_normalize(W.T)

    # Column standard deviations of the untouched inputs; used as the target scale.
    V_sd_target = V.std(axis=0, ddof=0)
    B_sd_target = B.std(axis=0, ddof=0)

    def _rescale(mat: np.ndarray, sd_target: np.ndarray) -> np.ndarray:
        # Columns whose current SD is (near) zero keep a factor of 1.
        sd_now = mat.std(axis=0, ddof=0)
        factor = np.divide(sd_target, sd_now, out=np.ones_like(sd_target), where=sd_now > 1e-12)
        return mat * factor

    V_cur, B_cur = V.copy(), B.copy()
    for _ in range(n_steps):
        # Both residuals are taken before either matrix is updated.
        v_resid = V_cur - V_cur.mean(axis=0, keepdims=True)
        b_resid = B_cur - B_cur.mean(axis=0, keepdims=True)
        V_cur = V_cur + lam * (b_resid @ host_to_virus)
        B_cur = B_cur + lam * (v_resid @ virus_to_host)

        # NOTE(review): the rescale is applied after every step here. The
        # source this was recovered from had lost its indentation, so confirm
        # it was not meant to run once after the loop (identical for n_steps=1).
        if preserve_scale:
            V_cur = _rescale(V_cur, V_sd_target)
            B_cur = _rescale(B_cur, B_sd_target)

    return V_cur, B_cur
[docs]
def residual_message_passing_df(
    V_clr_df: pd.DataFrame,
    B_clr_df: pd.DataFrame,
    W_vh_smooth_df: pd.DataFrame,
    lam: float = 0.1,
    n_steps: int = 1,
    preserve_scale: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Label-aware wrapper around ``residual_message_passing``.

    Features are restricted to those shared with the CRISPR matrix, and
    samples to those shared by both tables, before the arrays are handed to
    the NumPy implementation.

    Args:
        V_clr_df: Samples x viruses CLR-transformed abundance.
        B_clr_df: Samples x bacteria CLR-transformed abundance.
        W_vh_smooth_df: Viruses x bacteria smoothed CRISPR matrix.
        lam: Additive weight.
        n_steps: Number of propagation steps.
        preserve_scale: If True, restore original column standard deviations.

    Returns:
        Tuple (V_star_df, B_star_df, X_star_df), where X_star_df is the
        column-wise concatenation of the first two.

    Raises:
        ValueError: If no viruses, bacteria, or samples overlap.
    """
    vir_ids = V_clr_df.columns.intersection(W_vh_smooth_df.index)
    bac_ids = B_clr_df.columns.intersection(W_vh_smooth_df.columns)
    if vir_ids.empty:
        raise ValueError("No overlapping virus names between V_clr_df.columns and W_vh_smooth_df.index.")
    if bac_ids.empty:
        raise ValueError("No overlapping bacteria names between B_clr_df.columns and W_vh_smooth_df.columns.")

    V_sub = V_clr_df.loc[:, vir_ids].copy()
    B_sub = B_clr_df.loc[:, bac_ids].copy()
    W_sub = W_vh_smooth_df.loc[vir_ids, bac_ids].copy()

    # Only re-align rows when the two sample indexes are not already identical.
    if not V_sub.index.equals(B_sub.index):
        shared_samples = V_sub.index.intersection(B_sub.index)
        if shared_samples.empty:
            raise ValueError("No overlapping sample IDs between V_clr_df and B_clr_df.")
        V_sub, B_sub = V_sub.loc[shared_samples], B_sub.loc[shared_samples]

    V_star, B_star = residual_message_passing(
        V=V_sub.to_numpy(),
        B=B_sub.to_numpy(),
        W_vh_smooth=W_sub.to_numpy(),
        lam=lam,
        n_steps=n_steps,
        preserve_scale=preserve_scale,
    )

    # Re-attach the sample and feature labels to the raw arrays.
    V_star_df = pd.DataFrame(V_star, index=V_sub.index, columns=V_sub.columns)
    B_star_df = pd.DataFrame(B_star, index=B_sub.index, columns=B_sub.columns)
    return V_star_df, B_star_df, pd.concat([V_star_df, B_star_df], axis=1)
[docs]
def build_xstar_from_smoothed_crispr_residual(
    V_df: pd.DataFrame,
    B_df: pd.DataFrame,
    W_vh_smooth_df: pd.DataFrame,
    pseudocount: float = 1e-6,
    lam: float = 0.1,
    n_steps: int = 1,
    preserve_scale: bool = False,
    eps: float = 1e-12,
) -> dict[str, pd.DataFrame]:
    """Run the full residual additive X-star pipeline.

    Aligns and transforms the inputs with ``_build_common_inputs``, then
    applies ``residual_message_passing_df`` to the CLR tables.

    Args:
        V_df: Samples x viruses raw abundance.
        B_df: Samples x bacteria raw abundance.
        W_vh_smooth_df: Viruses x bacteria smoothed CRISPR matrix.
        pseudocount: CLR pseudocount.
        lam: Additive weight.
        n_steps: Number of propagation steps.
        preserve_scale: If True, restore original column standard deviations.
        eps: Numerical stability constant.

    Returns:
        Dict with keys: V_raw_aligned, B_raw_aligned, V_clr, B_clr, X_clr,
        V_star, B_star, X_star, W_smooth_aligned.
    """
    aligned = _build_common_inputs(
        V_df, B_df, W_vh_smooth_df, pseudocount=pseudocount, eps=eps
    )
    V0, B0, V_clr_df, B_clr_df, X_clr_df, W_sub = aligned

    V_star_df, B_star_df, X_star_df = residual_message_passing_df(
        V_clr_df, B_clr_df, W_sub, lam=lam, n_steps=n_steps, preserve_scale=preserve_scale
    )

    keys = (
        "V_raw_aligned", "B_raw_aligned",
        "V_clr", "B_clr", "X_clr",
        "V_star", "B_star", "X_star",
        "W_smooth_aligned",
    )
    values = (
        # Raw tables are cut down to the features that survived alignment.
        V0.loc[:, V_clr_df.columns], B0.loc[:, B_clr_df.columns],
        V_clr_df, B_clr_df, X_clr_df,
        V_star_df, B_star_df, X_star_df,
        W_sub,
    )
    return dict(zip(keys, values))
# ── CRISPR smoothing ───────────────────────────────────────────────────────────
[docs]
def smooth_crispr_bac_vir(
    crispr_df: pd.DataFrame,
    K_bac: pd.DataFrame,
    K_vir: pd.DataFrame,
    alpha: float = 1.0,
    preserve_original: bool = True,
) -> pd.DataFrame:
    """Smooth a CRISPR matrix using taxonomy kernels.

    With ``preserve_original=True``:
        W_smooth = (1 - alpha) * W + alpha * K_bac @ W @ K_vir
    With ``preserve_original=False`` the propagated term ``K_bac @ W @ K_vir``
    is returned as-is and ``alpha`` is not used.

    Args:
        crispr_df: Bacteria x viruses binary CRISPR matrix.
        K_bac: Bacteria taxonomy kernel.
        K_vir: Virus taxonomy kernel.
        alpha: Smoothing weight (1.0 = full kernel propagation).
        preserve_original: If True, blend with original; if False, use propagated only.

    Returns:
        Smoothed CRISPR DataFrame, restricted to the bacteria and viruses
        present in both ``crispr_df`` and the respective kernel index.

    Raises:
        ValueError: If ``crispr_df`` shares no bacteria with ``K_bac`` or no
            viruses with ``K_vir``.
    """
    shared_bac = crispr_df.index.intersection(K_bac.index)
    shared_vir = crispr_df.columns.intersection(K_vir.index)
    if shared_bac.empty:
        raise ValueError("No overlapping bacteria between crispr_df.index and K_bac.index")
    if shared_vir.empty:
        raise ValueError("No overlapping viruses between crispr_df.columns and K_vir.index")

    W = crispr_df.loc[shared_bac, shared_vir].copy()
    w_vals = W.to_numpy(dtype=float)
    kb_vals = K_bac.loc[shared_bac, shared_bac].to_numpy()
    kv_vals = K_vir.loc[shared_vir, shared_vir].to_numpy()

    propagated = kb_vals @ w_vals @ kv_vals
    smoothed = (1 - alpha) * w_vals + alpha * propagated if preserve_original else propagated
    return pd.DataFrame(smoothed, index=W.index, columns=W.columns)
[docs]
def build_taxonomy_kernel_from_shared_ranks(
    ids,
    tax_df: pd.DataFrame,
    ranks,
    weights=None,
    fill_value: str = "",
    normalize_rows: bool = True,
    strict: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build a taxonomy kernel matrix using weighted shared rank agreement.

    K[i,j] = sum_k weights[k] * I(rank_k[i] == rank_k[j]) if rank_k[i] != fill_value.
    K[i,i] = 1.

    Rank values are cleaned before comparison: a leading single-letter prefix
    such as ``g__`` is stripped, and placeholder labels (``nan``, ``None``,
    ``NA``, ``N/A``, ``unknown``, ``unclassified``, or an empty string) are
    mapped to ``fill_value`` so they never count as a match.

    Args:
        ids: Ordered collection of feature IDs to include.
        tax_df: Taxonomy DataFrame indexed by the same IDs. If an ID occurs
            more than once in the index, its first row is used.
        ranks: Ordered list of taxonomy rank column names.
        weights: Per-rank weights (default: 1..n). Normalized to sum to 1.
        fill_value: Value treated as missing (no match score).
        normalize_rows: Row-normalize the kernel.
        strict: Raise ValueError if any IDs are missing from tax_df.

    Returns:
        Tuple (K DataFrame, aligned taxonomy DataFrame). Both are indexed by
        the IDs found in ``tax_df``; the taxonomy frame holds the cleaned values.

    Raises:
        ValueError: If rank columns are missing, no ID overlaps the taxonomy,
            IDs are missing under ``strict``, ``weights`` has the wrong length,
            or the weights do not have a finite, nonzero sum.
    """
    ids = pd.Index(ids)
    ranks = list(ranks)

    missing_ranks = [r for r in ranks if r not in tax_df.columns]
    if missing_ranks:
        raise ValueError(f"Missing taxonomy columns: {missing_ranks}")

    common = ids.intersection(tax_df.index)
    missing = ids.difference(tax_df.index)
    if len(common) == 0:
        raise ValueError("No overlap between ids and taxonomy index")
    if len(missing) > 0:
        msg = f"Missing taxonomy for {len(missing)} taxa"
        if strict:
            raise ValueError(msg)
        print(msg + " — keeping only overlapping taxa.")

    if weights is None:
        weights = np.arange(1, len(ranks) + 1, dtype=float)
    else:
        weights = np.asarray(weights, dtype=float)
    if len(weights) != len(ranks):
        raise ValueError("weights must have same length as ranks")
    total = weights.sum()
    # A zero or non-finite total would turn every kernel entry into NaN.
    if weights.size and (not np.isfinite(total) or total == 0):
        raise ValueError("weights must have a finite, nonzero sum")
    weights = weights / total

    # One taxonomy row per ID keeps the kernel square and aligned with `common`.
    tax_unique = tax_df.loc[~tax_df.index.duplicated(keep="first")]
    tax = tax_unique.loc[common, ranks].copy()
    tax = tax.fillna(fill_value).astype(str)

    # Strip the rank prefix BEFORE looking for placeholders, so values such as
    # "g__unclassified" or a bare "g__" are recognized as missing rather than
    # being compared as if they were real taxon names.
    placeholders = {"", "nan", "None", "NA", "N/A", "unknown", "unclassified"}
    for r in ranks:
        stripped = tax[r].str.replace(r"^[a-z]__+", "", regex=True)
        tax[r] = stripped.where(~stripped.isin(placeholders), fill_value)

    X = tax.to_numpy(dtype=object)
    n = X.shape[0]
    K = np.zeros((n, n), dtype=float)
    for weight, col in zip(weights, X.T):
        known = col != fill_value
        same = (col[:, None] == col[None, :]) & known[:, None]
        K += weight * same

    np.fill_diagonal(K, 1.0)
    if normalize_rows:
        rs = K.sum(axis=1, keepdims=True)
        rs[rs == 0] = 1.0
        K = K / rs

    K = pd.DataFrame(K, index=tax.index, columns=tax.index)
    return K, tax
[docs]
def assign_crispr(cri_big: pd.DataFrame, cri_s: pd.DataFrame) -> pd.DataFrame:
    """Copy the overlapping cells of a source CRISPR matrix into a target matrix.

    Only cells whose row label AND column label exist in both frames are
    overwritten; everything else in the target is left as it was. The input
    ``cri_big`` is not modified.

    Args:
        cri_big: Target (full-size) matrix.
        cri_s: Source CRISPR matrix (subset).

    Returns:
        Updated copy of cri_big.
    """
    filled = cri_big.copy()
    shared_rows = filled.index.intersection(cri_s.index)
    shared_cols = filled.columns.intersection(cri_s.columns)

    # Nothing to transfer unless both axes overlap.
    if shared_rows.empty or shared_cols.empty:
        return filled

    filled.loc[shared_rows, shared_cols] = cri_s.loc[shared_rows, shared_cols]
    return filled
[docs]
def build_binary_crispr_matrix(
    raw_crispr_path: str,
    bacteria_features,
    virus_features,
    transpose_after_load: bool = True,
) -> pd.DataFrame:
    """Load a raw CRISPR network and build a binary bacteria x viruses matrix.

    The loaded table is copied into an all-zero bacteria x viruses frame
    wherever row and column labels overlap, then every nonzero cell is set
    to 1. Missing (NaN) cells are treated as "no hit" and stay 0.

    Args:
        raw_crispr_path: Path to the CSV of the raw CRISPR network.
        bacteria_features: Ordered list of bacteria feature IDs.
        virus_features: Ordered list of virus feature IDs.
        transpose_after_load: If True, the loaded table is transposed twice,
            i.e. it is effectively used in the orientation stored on disk; if
            False it is transposed once. See the review note in the body.

    Returns:
        Binary bacteria x viruses DataFrame.
    """
    df_crispr = pd.read_csv(raw_crispr_path, index_col=0)

    # NOTE(review): the unconditional transpose below undoes the conditional
    # one, so `transpose_after_load=True` leaves the orientation unchanged and
    # `False` flips it. That looks inverted relative to the parameter name, but
    # callers may rely on it, so the behavior is kept — confirm the intent.
    if transpose_after_load:
        df_crispr = df_crispr.T
    crispr = df_crispr.T

    # Row labels are compared as strings against `bacteria_features`.
    crispr.index = crispr.index.map(str)

    target = pd.DataFrame(
        np.zeros((len(bacteria_features), len(virus_features))),
        index=pd.Index(bacteria_features),
        columns=pd.Index(virus_features),
    )
    full = assign_crispr(target, crispr)

    # NaN != 0 evaluates to True, so without this step every missing cell
    # copied from the raw table would be binarized to 1 (a spurious hit).
    full = full.fillna(0)
    full[full != 0] = 1
    return full
[docs]
def get_hierarchies(df_b1: pd.DataFrame, df_v1: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Prepare bacteria and virus taxonomy hierarchies for CRISPR smoothing.

    Bacteria are indexed by ``progenomes_taxid_genus`` rendered as a string
    with any decimal suffix removed (e.g. ``816.0`` -> ``"816"``); rows whose
    ID contains ``"nan"`` are discarded. Viruses are indexed by ``lev0``. In
    both tables only the first row per ID is kept, and the ID is also written
    back as a regular column. The inputs are not modified.

    Args:
        df_b1: Bacteria taxonomy DataFrame with progenomes_taxid_genus column.
        df_v1: Virus taxonomy DataFrame with lev0 column.

    Returns:
        Tuple (bacteria_tax, virus_tax) with cleaned indexes.
    """
    df_b = df_b1.copy()
    df_v = df_v1.copy().set_index("lev0")

    df_b["progenomes_taxid_genus"] = [
        str(taxid).split(".")[0] for taxid in df_b["progenomes_taxid_genus"]
    ]
    df_b = df_b.set_index("progenomes_taxid_genus")

    # Filter with a boolean mask: selecting by a list of labels on a
    # non-unique index returns every matching row once per repeated label,
    # which grows quadratically with the number of duplicates.
    has_taxid = np.array(["nan" not in label for label in df_b.index], dtype=bool)
    df_b = df_b.loc[has_taxid]

    # De-duplicate on the ID only. A value-based drop_duplicates() ignores the
    # index and would silently discard a distinct taxid whose remaining
    # columns happen to equal those of another taxon.
    df_b = df_b.loc[~df_b.index.duplicated(keep="first")].copy()
    df_v = df_v.loc[~df_v.index.duplicated(keep="first")].copy()

    df_b["progenomes_taxid_genus"] = df_b.index
    df_v["lev0"] = df_v.index
    return df_b, df_v
[docs]
def build_smoothed_crispr_for_study(
    raw_crispr_path: str,
    bacteria_features,
    virus_features,
    tax_bac_path: str,
    tax_vir_path: str,
    bacterial_ranks,
    viral_ranks,
    bacterial_weights,
    viral_weights,
    alpha: float = 0.95,
    transpose_after_load: bool = True,
) -> dict[str, pd.DataFrame]:
    """Build a smoothed CRISPR matrix for a single study.

    Steps: read both taxonomy CSVs, clean them with ``get_hierarchies``,
    build the binary CRISPR matrix, build one row-normalized taxonomy kernel
    per domain, and blend the binary matrix with its kernel-propagated version.

    Args:
        raw_crispr_path: Path to the raw CRISPR network CSV.
        bacteria_features: Bacteria feature IDs (from processed abundance).
        virus_features: Virus feature IDs (from processed abundance).
        tax_bac_path: Path to bacteria taxonomy CSV.
        tax_vir_path: Path to virus taxonomy CSV.
        bacterial_ranks: Bacterial taxonomy rank columns for kernel.
        viral_ranks: Viral taxonomy rank columns for kernel.
        bacterial_weights: Per-rank weights for bacteria kernel.
        viral_weights: Per-rank weights for virus kernel.
        alpha: CRISPR smoothing weight.
        transpose_after_load: Passed to build_binary_crispr_matrix.

    Returns:
        Dict with crispr_binary, K_bac, K_vir, crispr_smooth, bac_tax_aligned, vir_tax_aligned.
    """
    raw_bac_tax = pd.read_csv(tax_bac_path, index_col=0)
    raw_vir_tax = pd.read_csv(tax_vir_path, index_col=0)
    bac_hierarchy, vir_hierarchy = get_hierarchies(raw_bac_tax, raw_vir_tax)

    crispr_binary = build_binary_crispr_matrix(
        raw_crispr_path=raw_crispr_path,
        bacteria_features=bacteria_features,
        virus_features=virus_features,
        transpose_after_load=transpose_after_load,
    )

    # Both kernels are row-normalized and tolerate IDs missing from the taxonomy.
    kernel_opts = {"normalize_rows": True, "strict": False}
    K_bac, bac_tax_aligned = build_taxonomy_kernel_from_shared_ranks(
        ids=crispr_binary.index,
        tax_df=bac_hierarchy,
        ranks=bacterial_ranks,
        weights=bacterial_weights,
        **kernel_opts,
    )
    K_vir, vir_tax_aligned = build_taxonomy_kernel_from_shared_ranks(
        ids=crispr_binary.columns,
        tax_df=vir_hierarchy,
        ranks=viral_ranks,
        weights=viral_weights,
        **kernel_opts,
    )

    crispr_smooth = smooth_crispr_bac_vir(
        crispr_df=crispr_binary,
        K_bac=K_bac,
        K_vir=K_vir,
        alpha=alpha,
        preserve_original=True,
    )

    result = {"crispr_binary": crispr_binary, "K_bac": K_bac, "K_vir": K_vir}
    result["crispr_smooth"] = crispr_smooth
    result["bac_tax_aligned"] = bac_tax_aligned
    result["vir_tax_aligned"] = vir_tax_aligned
    return result
[docs]
def crispr_matrix_aggregate_viruses(
    df_crispr: pd.DataFrame,
    vir_tax: pd.DataFrame,
    *,
    bac_col: int = 0,
    vir_col: int = 1,
    vir_rank: str = "lev0",
    vir_id_col=None,
    dropna_vir: bool = True,
    dtype=int,
) -> pd.DataFrame:
    """Parse a SpacePHARER output TSV and aggregate viruses by taxonomy rank.

    The bacterial taxid is read from the text between the first ``>`` and the
    following ``.`` of the spacer ID; rows where that fails, or where the
    viral contig ID is missing, are skipped. Hits are counted per
    (taxid, contig) pair and contig columns are then summed per viral group.

    Args:
        df_crispr: Raw SpacePHARER predictions DataFrame (no header).
        vir_tax: Virus taxonomy DataFrame.
        bac_col: Column index for bacterial spacer IDs.
        vir_col: Column index for viral contig IDs.
        vir_rank: Viral taxonomy rank column to aggregate to.
        vir_id_col: Optional viral ID column in vir_tax; uses index if None.
        dropna_vir: If True, drop contigs with no (or a blank) group label;
            if False, collect them under the group ``"Unassigned"``.
        dtype: Output matrix dtype.

    Returns:
        Bacteria x viral_groups crosstab matrix, with string row labels.

    Raises:
        KeyError: If ``vir_rank`` or ``vir_id_col`` is not a column of ``vir_tax``.
    """
    # Validate the taxonomy columns up front so bad arguments fail before any work is done.
    if vir_rank not in vir_tax.columns:
        raise KeyError(f"vir_tax missing column {vir_rank!r}. Available: {list(vir_tax.columns)[:20]}")
    if vir_id_col is not None and vir_id_col not in vir_tax.columns:
        raise KeyError(f"vir_tax missing id column {vir_id_col!r}")

    def parse_bac_taxid(s):
        if pd.isna(s):
            return None
        try:
            return int(str(s).split(">")[1].split(".")[0])
        except (IndexError, ValueError):
            # No ">" separator, or the token after it is not an integer.
            return None

    df1 = df_crispr[[bac_col, vir_col]].copy()
    # Drop missing contig IDs before the string cast; casting first would turn
    # NaN into a literal "nan" contig that survives the dropna below.
    df1 = df1.loc[df1[vir_col].notna()].copy()
    df1["bac_taxid"] = df1[bac_col].map(parse_bac_taxid)
    df1["contig"] = df1[vir_col].astype(str).str.strip()
    df1 = df1.dropna(subset=["bac_taxid", "contig"])
    df1["bac_taxid"] = df1["bac_taxid"].astype(int)

    M = pd.crosstab(df1["bac_taxid"], df1["contig"]).astype(dtype)

    # Build a contig -> group lookup with one entry per contig ID.
    if vir_id_col is None:
        vt = vir_tax[[vir_rank]].copy()
        vt.index = vt.index.astype(str).str.strip()
        vt = vt.loc[~vt.index.duplicated(keep="first")]
        contig_to_group = vt[vir_rank]
    else:
        vt = vir_tax[[vir_id_col, vir_rank]].copy()
        vt = vt.dropna(subset=[vir_id_col])
        vt[vir_id_col] = vt[vir_id_col].astype(str).str.strip()
        vt = vt.drop_duplicates(subset=[vir_id_col], keep="first")
        contig_to_group = vt.set_index(vir_id_col)[vir_rank]

    labels = pd.Series(M.columns.astype(str), index=M.columns).map(contig_to_group)
    valid = labels.notna() & (labels.astype(str).str.strip() != "")
    if dropna_vir:
        M = M.loc[:, valid.values]
        labels = labels.loc[valid]
    else:
        # Blank labels are as unusable as missing ones; group both together
        # instead of leaving an empty-string column in the output.
        labels = labels.where(valid, "Unassigned")

    out = M.T.groupby(labels.astype(str).str.strip()).sum().T
    out.index = out.index.astype(str)
    return out
# ── Abundance helpers ──────────────────────────────────────────────────────────
[docs]
def aggregate_otu_columns_by_rank_skip_nan(
    otu: pd.DataFrame,
    tax: pd.DataFrame,
    rank: str,
) -> pd.DataFrame:
    """Sum OTU/ASV columns that share a taxonomy label at the given rank.

    Only ASVs present in both ``otu.columns`` and ``tax.index`` are used, and
    ASVs whose label is missing or blank are left out entirely. Labels are
    whitespace-stripped before grouping.

    Args:
        otu: Samples x ASVs abundance DataFrame.
        tax: ASVs x ranks taxonomy DataFrame.
        rank: Taxonomy column to aggregate to.

    Returns:
        Samples x rank-groups abundance DataFrame.

    Raises:
        ValueError: If ``otu.columns`` and ``tax.index`` share no IDs.
    """
    shared_asvs = otu.columns.intersection(tax.index)
    if shared_asvs.empty:
        raise ValueError("No common ASV/OTU IDs between abundance columns and taxonomy index.")

    rank_labels = tax.loc[shared_asvs, rank]
    usable = rank_labels.notna() & (rank_labels.astype(str).str.strip() != "")

    kept_abundance = otu.loc[:, shared_asvs].loc[:, usable.values]
    group_names = rank_labels.loc[usable].astype(str).str.strip()

    # Group the transposed table so ASVs (rows there) are summed per label.
    return kept_abundance.T.groupby(group_names).sum().T
[docs]
def prepare_bacteria_genus_abundance(
    otu: pd.DataFrame,
    tax: pd.DataFrame,
    rank: str = "target_taxids",
) -> pd.DataFrame:
    """Aggregate bacteria OTUs to the selected taxonomy rank and sanitize feature names.

    Pipeline: aggregate with ``aggregate_otu_columns_by_rank_skip_nan``, apply
    ``rename_clostridium_sensu_stricto``, pass the column labels through
    ``sanitize_index``, then apply ``apply_custom_renames``.

    Args:
        otu: Samples x ASVs raw abundance.
        tax: ASV taxonomy DataFrame with a column matching rank.
        rank: Taxonomy column to aggregate to.

    Returns:
        Samples x genus-level abundance DataFrame with sanitized column names.
    """
    aggregated = aggregate_otu_columns_by_rank_skip_nan(otu, tax, rank)
    renamed = rename_clostridium_sensu_stricto(aggregated)
    renamed.columns = sanitize_index(renamed.columns)
    return apply_custom_renames(renamed)
[docs]
def remove_disease_columns_from_virus_abundance(V: pd.DataFrame) -> pd.DataFrame:
    """Drop known phenotype/metadata columns from a viral abundance table.

    A column is removed when its name, converted to ``str``, exactly matches
    one of a fixed list of metadata names. The dropped names are printed.
    The input frame is not modified.

    Args:
        V: Viral abundance DataFrame.

    Returns:
        Cleaned copy without non-feature columns.
    """
    non_feature_names = frozenset(
        (
            "disease", "Disease", "disease_original", "disease_binary",
            "phenotype", "Phenotype", "label", "Label",
            "class", "Class", "group", "Group",
            "sample_id", "SampleID", "subject_id", "SubjectNo",
            "Reads", "reads",
        )
    )
    cleaned = V.copy()
    cols_to_drop = [col for col in cleaned.columns if str(col) in non_feature_names]
    if not cols_to_drop:
        return cleaned

    print("Dropping non-feature columns from viral abundance:", cols_to_drop)
    return cleaned.drop(columns=cols_to_drop)
[docs]
def prevalence_filter_df(df: pd.DataFrame, prevalence: float = 0.10, verbose: bool = True) -> pd.DataFrame:
    """Keep only features that are positive in enough samples.

    The required sample count is ``int(prevalence * n_samples)`` (truncated
    toward zero) with a floor of 1, so a feature must be > 0 in at least that
    many rows to be kept.

    Args:
        df: Samples x features DataFrame.
        prevalence: Minimum fractional prevalence threshold.
        verbose: Print kept/total feature count.

    Returns:
        Filtered DataFrame.
    """
    n_samples, n_features = df.shape
    min_samples = max(int(prevalence * n_samples), 1)

    n_present = (df > 0).sum(axis=0)
    keep = n_present >= min_samples

    if verbose:
        print(
            f"prevalence={prevalence} -> min_samples={min_samples}/{n_samples}; "
            f"kept {int(keep.sum())}/{n_features} features"
        )
    return df.loc[:, keep]
[docs]
def summarize_df(name: str, df: pd.DataFrame) -> None:
    """Print a one-line summary: shape plus uniqueness of both axes.

    Args:
        name: Label for the printout.
        df: DataFrame to summarize.
    """
    rows_unique = df.index.is_unique
    cols_unique = df.columns.is_unique
    print(f"{name}: shape={df.shape}, index_unique={rows_unique}, columns_unique={cols_unique}")
[docs]
def study_outdir(study: str, subdir: str, output_root: Path) -> Path:
    """Build the output directory path for a study.

    Args:
        study: Study identifier string.
        subdir: Sub-directory name within the study folder; a falsy value
            (e.g. an empty string) selects the study folder itself.
        output_root: Root output directory.

    Returns:
        Path to the study sub-directory.
    """
    study_dir = output_root / study
    if not subdir:
        return study_dir
    return study_dir / subdir
[docs]
def out_path(study: str, subdir: str, filename: str, output_root: Path) -> Path:
    """Build the full path of an output file inside a study directory.

    The directory part is resolved by ``study_outdir``.

    Args:
        study: Study identifier string.
        subdir: Sub-directory name.
        filename: File name.
        output_root: Root output directory.

    Returns:
        Full Path to the output file.
    """
    target_dir = study_outdir(study, subdir, output_root)
    return target_dir / filename
[docs]
def orient_W_viruses_by_bacteria(
    W_df: pd.DataFrame,
    V_df: pd.DataFrame,
    B_df: pd.DataFrame,
    verbose: bool = True,
) -> pd.DataFrame:
    """Detect and correct the orientation of a CRISPR matrix to be viruses x bacteria.

    Labels are compared as strings. The matrix is transposed only when its
    rows overlap the bacteria more than the viruses AND its columns overlap
    the viruses more than the bacteria; otherwise it is returned unchanged.

    Args:
        W_df: CRISPR matrix (may be bacteria x viruses or viruses x bacteria).
        V_df: Viral abundance (samples x viruses).
        B_df: Bacterial abundance (samples x bacteria).
        verbose: Print overlap counts and, when it happens, the transposition notice.

    Returns:
        W_df in viruses x bacteria orientation.
    """
    # Build each ID set once; every overlap count below reuses them.
    row_ids = set(W_df.index.astype(str))
    col_ids = set(W_df.columns.astype(str))
    virus_ids = set(map(str, V_df.columns))
    bacteria_ids = set(map(str, B_df.columns))

    n_vir_row = len(row_ids & virus_ids)
    n_bac_col = len(col_ids & bacteria_ids)
    n_bac_row = len(row_ids & bacteria_ids)
    n_vir_col = len(col_ids & virus_ids)

    if verbose:
        print("W rows overlapping viruses:", n_vir_row, "/", V_df.shape[1])
        print("W cols overlapping bacteria:", n_bac_col, "/", B_df.shape[1])
        print("W rows overlapping bacteria:", n_bac_row, "/", B_df.shape[1])
        print("W cols overlapping viruses:", n_vir_col, "/", V_df.shape[1])

    if n_bac_row > n_vir_row and n_vir_col > n_bac_col:
        # Respect `verbose` here too so a quiet call produces no output at all.
        if verbose:
            print("Detected bacteria x viruses orientation; transposing W.")
        W_df = W_df.T
    return W_df