"""Network stage: common abundance, shrinkage, raw/smoothed CRISPR, residual X*.
Slim core. The math follows the paper (CLR + W̃ taxonomy smoothing + residual
message passing). Schäfer-Strimmer shrinkage correlations are computed both
on the joint Z = [B_clr V_clr] and on the propagated Z* = [B_star V_star].
Outputs go directly under ``cfg.output_root`` — no per-study sub-folder.
"""
from __future__ import annotations
import logging
from pathlib import Path
import pandas as pd
from capellini.config import CapelliniConfig
from capellini.utils.io import read_table, write_df
from capellini.utils.taxonomy import (
clean_df_ids,
load_bacteria_taxonomy,
)
from capellini.utils.transforms import clr, schaefer_strimmer_corr
from capellini.utils.network_utils import (
align_abundance_from_metadata,
build_smoothed_crispr_for_study,
build_xstar_from_smoothed_crispr,
crispr_matrix_aggregate_viruses,
orient_W_viruses_by_bacteria,
prepare_bacteria_genus_abundance,
prevalence_filter_df,
remove_disease_columns_from_virus_abundance,
)
# Module-level logger; every sub-stage below reports progress through it.
logger: logging.Logger = logging.getLogger(__name__)
def _subdir(cfg: CapelliniConfig, name: str) -> Path:
out = Path(cfg.output_root) / name
out.mkdir(parents=True, exist_ok=True)
return out
def _shrinkage_block(joint_clr: pd.DataFrame, B_cols, V_cols) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Schäfer-Strimmer shrinkage correlation on a CLR-transformed ``[B V]`` frame.

    Args:
        joint_clr: frame whose columns hold both the bacterial and the viral
            features (already CLR-transformed by the caller).
        B_cols: labels of the bacterial columns; they index the rows of the
            returned cross block.
        V_cols: labels of the viral columns; they index the columns of the
            returned cross block.

    Returns:
        ``(corr_full, corr_BV, info)``: the full correlation matrix, its
        bacteria × virus sub-block, and whatever ``schaefer_strimmer_corr``
        returns as its second value.
    """
    full_corr, shrink_info = schaefer_strimmer_corr(joint_clr)
    cross_block = full_corr.loc[B_cols, V_cols]
    return full_corr, cross_block, shrink_info
[docs]
def build_common_abundance_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Build the sample-aligned, prevalence-filtered virus and bacteria tables.

    Steps:
      1. Virus abundance (``cfg.virus_abundance_raw``) is passed through
         ``remove_disease_columns_from_virus_abundance`` and ``clean_df_ids``.
      2. The bacterial OTU table (``cfg.bacteria_otu``) is combined with its
         taxonomy (``cfg.bacteria_taxonomy``) via
         ``prepare_bacteria_genus_abundance`` at ``cfg.bacteria_taxonomy_rank``.
      3. Both tables are aligned against the metadata (``cfg.metadata_path``,
         filtered by ``cfg.keep_column``).
      4. Both tables are prevalence-filtered with ``cfg.prevalence``.

    Returns:
        Mapping ``{"V", "B", "meta"}`` -> CSV path under ``<output_root>/common``.
    """
    logger.info("common abundance: loading raw inputs")

    virus_raw = clean_df_ids(
        remove_disease_columns_from_virus_abundance(read_table(cfg.virus_abundance_raw))
    )
    otu_table = read_table(cfg.bacteria_otu)
    bac_taxonomy = load_bacteria_taxonomy(cfg.bacteria_taxonomy)
    bac_ranked = prepare_bacteria_genus_abundance(
        otu_table, bac_taxonomy, rank=cfg.bacteria_taxonomy_rank
    )
    metadata = read_table(cfg.metadata_path, index_col=None)

    virus, bacteria, meta_aligned = align_abundance_from_metadata(
        virus_abundance=virus_raw,
        bacteria_abundance=bac_ranked,
        metadata=metadata,
        keep_col=cfg.keep_column,
    )
    virus, bacteria = (
        prevalence_filter_df(frame, prevalence=cfg.prevalence)
        for frame in (virus, bacteria)
    )

    out_dir = _subdir(cfg, "common")
    outputs = (
        ("V", "V_processed.csv", virus),
        ("B", "B_processed.csv", bacteria),
        ("meta", "metadata_aligned.csv", meta_aligned),
    )
    paths: dict[str, Path] = {}
    for key, filename, frame in outputs:
        paths[key] = out_dir / filename
        write_df(frame, paths[key])

    logger.info("common abundance written to %s", out_dir)
    return paths
[docs]
def build_shrinkage_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Schäfer-Strimmer (SS) shrinkage correlation on Z = [B_clr V_clr].

    Reads the processed tables written by ``build_common_abundance_one``,
    reorders the bacterial rows to the virus row index, CLR-transforms both
    with ``cfg.pseudocount``, prefixes the columns with ``B__`` / ``V__``, and
    writes the full and bacteria × virus correlation matrices under
    ``<output_root>/shrinkage``.

    Returns:
        Mapping ``{"corr_full", "corr_bv"}`` -> written file path.
    """
    out_dir = _subdir(cfg, "shrinkage")
    common_dir = Path(cfg.output_root) / "common"
    virus = read_table(common_dir / "V_processed.csv")
    # Follow the virus sample order for the bacterial rows.
    bacteria = read_table(common_dir / "B_processed.csv").loc[virus.index]

    def _clr_frame(frame: pd.DataFrame) -> pd.DataFrame:
        values = clr(frame, pseudocount=cfg.pseudocount)
        return pd.DataFrame(values, index=frame.index, columns=frame.columns)

    v_clr = _clr_frame(virus)
    b_clr = _clr_frame(bacteria)
    # Prefixes keep bacterial and viral labels distinct in the joint matrix.
    v_block = v_clr.add_prefix("V__")
    b_block = b_clr.add_prefix("B__")
    joint = pd.concat([b_block, v_block], axis=1)
    corr_full, corr_bv, info = _shrinkage_block(joint, b_block.columns, v_block.columns)

    paths = {
        "corr_full": out_dir / "shrinkage_corr_full.csv.gz",
        "corr_bv": out_dir / "shrinkage_corr_BV.csv.gz",
    }
    for key, frame in (("corr_full", corr_full), ("corr_bv", corr_bv)):
        write_df(frame, paths[key])
    logger.info("shrinkage written (lambda=%s)", info)
    return paths
[docs]
def build_raw_crispr_one(cfg: CapelliniConfig) -> Path:
    """Aggregate the SpacePHARER predictions into a (bac × vOTU) CRISPR matrix.

    The predictions file (``cfg.phage_host_predictions``) is read as headerless,
    tab-separated text, skipping ``#`` comment lines. Viruses are aggregated at
    ``cfg.aggregate_viral_rank`` using the viral taxonomy in ``cfg.tax_vir``.

    Returns:
        Path of the written ``<output_root>/crispr_raw/crispr_net.csv.gz``.
    """
    out_dir = _subdir(cfg, "crispr_raw")
    predictions = pd.read_csv(
        cfg.phage_host_predictions, sep="\t", header=None, comment="#"
    )
    viral_taxonomy = read_table(cfg.tax_vir)
    matrix = crispr_matrix_aggregate_viruses(
        predictions, viral_taxonomy, vir_rank=cfg.aggregate_viral_rank
    )
    target = out_dir / "crispr_net.csv.gz"
    write_df(matrix, target)
    logger.info("raw CRISPR matrix written: %s", target)
    return target
[docs]
def build_smooth_crispr_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Smooth the raw CRISPR matrix via taxonomy kernels; persist all artefacts.

    Uses the processed abundance tables in ``<output_root>/common`` only for
    their feature columns, and the raw matrix from
    ``<output_root>/crispr_raw/crispr_net.csv.gz``. Writes ``crispr_binary``,
    ``crispr_smooth``, ``K_bac`` and ``K_vir`` as returned by
    ``build_smoothed_crispr_for_study``, plus the transpose of
    ``crispr_smooth`` as ``crispr_smooth_vir_bac``, all under
    ``<output_root>/crispr_smooth``.

    Returns:
        Mapping of artefact name -> written file path.
    """
    out_dir = _subdir(cfg, "crispr_smooth")
    root = Path(cfg.output_root)
    virus = read_table(root / "common" / "V_processed.csv")
    bacteria = read_table(root / "common" / "B_processed.csv")

    artefacts = build_smoothed_crispr_for_study(
        raw_crispr_path=str(root / "crispr_raw" / "crispr_net.csv.gz"),
        bacteria_features=bacteria.columns,
        virus_features=virus.columns,
        tax_bac_path=cfg.tax_bac_for_smoothing,
        tax_vir_path=cfg.tax_vir,
        bacterial_ranks=cfg.bacterial_ranks,
        viral_ranks=cfg.viral_ranks,
        bacterial_weights=cfg.bacterial_weights,
        viral_weights=cfg.viral_weights,
        alpha=cfg.crispr_smooth_alpha,
        transpose_after_load=cfg.transpose_raw_crispr_after_load,
    )

    paths: dict[str, Path] = {
        artefact: out_dir / f"{artefact}.csv.gz"
        for artefact in ("crispr_binary", "crispr_smooth", "K_bac", "K_vir")
    }
    for artefact, target in paths.items():
        write_df(artefacts[artefact], target)

    # Also persist the transposed smoothed matrix (read first by build_xstar_one).
    paths["crispr_smooth_vir_bac"] = out_dir / "crispr_smooth_vir_bac.csv.gz"
    write_df(artefacts["crispr_smooth"].T, paths["crispr_smooth_vir_bac"])

    logger.info("smoothed CRISPR written to %s", out_dir)
    return paths
[docs]
def build_xstar_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Residual X* propagation + Schäfer-Strimmer correlations on Z*.

    Reads the processed abundance tables from ``<output_root>/common`` and the
    smoothed CRISPR matrix from ``<output_root>/crispr_smooth``. It prefers
    ``crispr_smooth_vir_bac.csv.gz`` and otherwise transposes
    ``crispr_smooth.csv.gz``. The matrix is re-oriented with
    ``orient_W_viruses_by_bacteria`` and passed to
    ``build_xstar_from_smoothed_crispr`` with ``cfg.pseudocount``, ``cfg.lam``
    and ``cfg.n_steps``. Shrinkage correlations are then computed on
    ``Z* = [B_star V_star]``. All artefacts are written under
    ``<output_root>/xstar``.

    Returns:
        Mapping of artefact name -> written file path.

    Raises:
        KeyError: if the bacterial table is missing samples that are present
            in the virus table (same behaviour as ``build_shrinkage_one``).
    """
    out_dir = _subdir(cfg, "xstar")
    common = Path(cfg.output_root) / "common"
    smooth = Path(cfg.output_root) / "crispr_smooth"
    V = read_table(common / "V_processed.csv")
    # Align bacterial rows to the virus sample order, exactly as
    # build_shrinkage_one does. The propagation and the [B_star V_star]
    # concatenation below then see both tables with identical sample rows.
    B = read_table(common / "B_processed.csv").loc[V.index]

    W_path = smooth / "crispr_smooth_vir_bac.csv.gz"
    if W_path.exists():
        W = read_table(W_path)
    else:
        # Fall back to the non-transposed artefact written by the smoothing stage.
        W = read_table(smooth / "crispr_smooth.csv.gz").T
    W = orient_W_viruses_by_bacteria(W, V, B)

    res = build_xstar_from_smoothed_crispr(
        V_df=V, B_df=B, W_vh_smooth_df=W,
        pseudocount=cfg.pseudocount, lam=cfg.lam, n_steps=cfg.n_steps,
    )

    # Shrinkage on Z* = [B_star V_star]; prefixes keep B and V labels distinct.
    V_pref = res["V_star"].add_prefix("V__")
    B_pref = res["B_star"].add_prefix("B__")
    joint_star = pd.concat([B_pref, V_pref], axis=1)
    corr_star_full, corr_star_bv, info_star = _shrinkage_block(
        joint_star, B_pref.columns, V_pref.columns
    )

    paths = {
        "V_clr": out_dir / "V_clr.csv.gz",
        "B_clr": out_dir / "B_clr.csv.gz",
        "V_star": out_dir / "V_star.csv.gz",
        "B_star": out_dir / "B_star.csv.gz",
        "X_star": out_dir / "X_star.csv.gz",
        "shrinkage_xstar_full": out_dir / "shrinkage_xstar_corr_full.csv.gz",
        "shrinkage_xstar_BV": out_dir / "shrinkage_xstar_corr_BV.csv.gz",
    }
    outputs = {
        "V_clr": res["V_clr"],
        "B_clr": res["B_clr"],
        "V_star": res["V_star"],
        "B_star": res["B_star"],
        "X_star": res["X_star"],
        "shrinkage_xstar_full": corr_star_full,
        "shrinkage_xstar_BV": corr_star_bv,
    }
    for key, frame in outputs.items():
        write_df(frame, paths[key])

    logger.info("X* + Z*-shrinkage written to %s (lambda=%s)", out_dir, info_star)
    return paths
[docs]
def run_network(cfg: CapelliniConfig) -> dict[str, dict]:
    """Run the network stage according to the run_* flags on the config.

    Sub-stages run in dependency order: common abundance, shrinkage, raw
    CRISPR, smoothed CRISPR, then X*. Each one is skipped when its ``run_*``
    flag is falsy; a flag is read only after the preceding sub-stages finish.

    Returns:
        Mapping of sub-stage name -> dict of paths that sub-stage produced.

    Raises:
        ValueError: if ``cfg.output_root`` is empty or unset.
    """
    if not cfg.output_root:
        raise ValueError("Network stage requires cfg.output_root (or enhanced_networks_folder)")
    Path(cfg.output_root).mkdir(parents=True, exist_ok=True)

    # (result key, config flag name, runner) in execution order.
    stages = (
        ("common", "run_common_abundance", build_common_abundance_one),
        ("shrinkage", "run_shrinkage_correlations", build_shrinkage_one),
        (
            "crispr_raw",
            "run_raw_crispr_networks",
            lambda c: {"crispr_net": build_raw_crispr_one(c)},
        ),
        ("crispr_smooth", "run_smooth_crispr", build_smooth_crispr_one),
        ("xstar", "run_xstar", build_xstar_one),
    )
    results: dict[str, dict] = {}
    for key, flag_name, runner in stages:
        if getattr(cfg, flag_name):
            results[key] = runner(cfg)

    logger.info("Network stage complete (%d sub-stages)", len(results))
    return results