# Source code for capellini.stages.network

"""Network stage: common abundance, shrinkage, raw/smoothed CRISPR, residual X*.

Slim core. The math follows the paper (CLR + W̃ taxonomy smoothing + residual
message passing). Schäfer-Strimmer shrinkage correlations are computed both
on the joint Z = [B_clr  V_clr] and on the propagated Z* = [B_star  V_star].

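Outputs are written to per-stage sub-folders (``common``, ``shrinkage``,
``crispr_raw``, ``crispr_smooth``, ``xstar``) directly under
``cfg.output_root``; there is no per-study sub-folder.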
"""

from __future__ import annotations

import logging
from pathlib import Path

import pandas as pd

from capellini.config import CapelliniConfig
from capellini.utils.io import read_table, write_df
from capellini.utils.taxonomy import (
    clean_df_ids,
    load_bacteria_taxonomy,
)
from capellini.utils.transforms import clr, schaefer_strimmer_corr
from capellini.utils.network_utils import (
    align_abundance_from_metadata,
    build_smoothed_crispr_for_study,
    build_xstar_from_smoothed_crispr,
    crispr_matrix_aggregate_viruses,
    orient_W_viruses_by_bacteria,
    prepare_bacteria_genus_abundance,
    prevalence_filter_df,
    remove_disease_columns_from_virus_abundance,
)

logger = logging.getLogger(__name__)


def _subdir(cfg: CapelliniConfig, name: str) -> Path:
    """Return ``<cfg.output_root>/<name>``, creating the directory if needed.

    Missing parent directories are created as well, and an already existing
    directory is not treated as an error.
    """
    target = Path(cfg.output_root).joinpath(name)
    target.mkdir(exist_ok=True, parents=True)
    return target


def _shrinkage_block(joint_clr: pd.DataFrame, B_cols, V_cols) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Compute Schäfer-Strimmer correlations on a CLR-transformed ``[B  V]`` frame.

    Parameters
    ----------
    joint_clr:
        Frame whose columns hold both the bacterial and the viral features
        (the callers in this module build it with ``pd.concat(..., axis=1)``).
    B_cols, V_cols:
        Labels selecting the rows (``B_cols``) and columns (``V_cols``) of the
        bacteria-by-virus sub-block of the correlation matrix.

    Returns
    -------
    tuple
        ``(corr_full, corr_BV, info)``: the full correlation frame, its
        ``B_cols`` x ``V_cols`` sub-block, and the info object returned by
        ``schaefer_strimmer_corr``.
    """
    full_corr, ss_info = schaefer_strimmer_corr(joint_clr)
    return full_corr, full_corr.loc[B_cols, V_cols], ss_info


def build_common_abundance_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Build the shared virus / bacteria abundance tables aligned to the metadata.

    The function performs these steps:

    * Load the raw virus abundance, drop its disease columns and clean its IDs.
    * Collapse the bacterial OTU table to ``cfg.bacteria_taxonomy_rank``.
    * Align both tables with the metadata using ``cfg.keep_column``.
    * Apply ``prevalence_filter_df`` with ``cfg.prevalence`` to each table.

    Returns
    -------
    dict[str, Path]
        Paths of the ``V``, ``B`` and ``meta`` CSV files written under
        ``<output_root>/common``.
    """
    logger.info("common abundance: loading raw inputs")

    # Viral abundance: drop disease columns first, then normalise IDs.
    virus_raw = clean_df_ids(
        remove_disease_columns_from_virus_abundance(read_table(cfg.virus_abundance_raw))
    )

    # Bacterial abundance collapsed to the configured taxonomic rank.
    bacteria_ranked = prepare_bacteria_genus_abundance(
        read_table(cfg.bacteria_otu),
        load_bacteria_taxonomy(cfg.bacteria_taxonomy),
        rank=cfg.bacteria_taxonomy_rank,
    )

    metadata = read_table(cfg.metadata_path, index_col=None)
    virus, bacteria, meta_aligned = align_abundance_from_metadata(
        virus_abundance=virus_raw,
        bacteria_abundance=bacteria_ranked,
        metadata=metadata,
        keep_col=cfg.keep_column,
    )
    virus, bacteria = (
        prevalence_filter_df(table, prevalence=cfg.prevalence) for table in (virus, bacteria)
    )

    out_dir = _subdir(cfg, "common")
    paths = {
        key: out_dir / filename
        for key, filename in (
            ("V", "V_processed.csv"),
            ("B", "B_processed.csv"),
            ("meta", "metadata_aligned.csv"),
        )
    }
    for key, frame in (("V", virus), ("B", bacteria), ("meta", meta_aligned)):
        write_df(frame, paths[key])

    logger.info("common abundance written to %s", out_dir)
    return paths


def build_shrinkage_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Compute the Schäfer-Strimmer shrinkage correlation on Z = [B_clr  V_clr].

    The function reads ``common/V_processed.csv`` and ``common/B_processed.csv``,
    re-indexing the bacteria table to the virus table's index. It
    CLR-transforms each table with ``cfg.pseudocount`` and prefixes the columns
    with ``B__`` / ``V__``. It then runs the shrinkage estimator on the
    column-wise concatenation.

    Returns
    -------
    dict[str, Path]
        ``corr_full`` (full correlation) and ``corr_bv`` (bacteria x virus
        block) files under ``<output_root>/shrinkage``.
    """
    out_dir = _subdir(cfg, "shrinkage")
    common_dir = Path(cfg.output_root) / "common"

    virus = read_table(common_dir / "V_processed.csv")
    bacteria = read_table(common_dir / "B_processed.csv").loc[virus.index]

    def _clr_frame(frame: pd.DataFrame) -> pd.DataFrame:
        # Wrap the CLR output in a DataFrame carrying the input's labels.
        return pd.DataFrame(
            clr(frame, pseudocount=cfg.pseudocount),
            index=frame.index,
            columns=frame.columns,
        )

    v_prefixed = _clr_frame(virus).add_prefix("V__")
    b_prefixed = _clr_frame(bacteria).add_prefix("B__")

    joint = pd.concat([b_prefixed, v_prefixed], axis=1)
    corr_df, bv, info = _shrinkage_block(joint, b_prefixed.columns, v_prefixed.columns)

    paths = {
        "corr_full": out_dir / "shrinkage_corr_full.csv.gz",
        "corr_bv": out_dir / "shrinkage_corr_BV.csv.gz",
    }
    for key, frame in (("corr_full", corr_df), ("corr_bv", bv)):
        write_df(frame, paths[key])

    logger.info("shrinkage written (lambda=%s)", info)
    return paths


def build_raw_crispr_one(cfg: CapelliniConfig) -> Path:
    """Aggregate the SpacePHARER predictions into a (bac × vOTU) CRISPR matrix.

    The predictions file is read as header-less, tab-separated text, and lines
    starting with ``#`` are skipped. Viruses are aggregated at
    ``cfg.aggregate_viral_rank`` using the taxonomy table at ``cfg.tax_vir``.

    Returns
    -------
    Path
        Location of ``crispr_net.csv.gz`` under ``<output_root>/crispr_raw``.
    """
    target = _subdir(cfg, "crispr_raw") / "crispr_net.csv.gz"

    predictions = pd.read_csv(
        cfg.phage_host_predictions,
        sep="\t",
        header=None,
        comment="#",
    )
    matrix = crispr_matrix_aggregate_viruses(
        predictions,
        read_table(cfg.tax_vir),
        vir_rank=cfg.aggregate_viral_rank,
    )

    write_df(matrix, target)
    logger.info("raw CRISPR matrix written: %s", target)
    return target


def build_smooth_crispr_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Smooth the raw CRISPR matrix via taxonomy kernels and persist all artefacts.

    The feature sets come from the columns of the common ``V`` / ``B`` tables.
    The input matrix is ``crispr_raw/crispr_net.csv.gz``, which is written by
    ``build_raw_crispr_one``. Taxonomy paths, ranks, weights, the smoothing
    ``alpha`` and the transpose flag are all taken from ``cfg``.

    Returns
    -------
    dict[str, Path]
        Paths for ``crispr_binary``, ``crispr_smooth``, ``K_bac``, ``K_vir`` and
        ``crispr_smooth_vir_bac`` (the transpose of ``crispr_smooth``).
    """
    out_dir = _subdir(cfg, "crispr_smooth")
    root = Path(cfg.output_root)

    virus = read_table(root / "common" / "V_processed.csv")
    bacteria = read_table(root / "common" / "B_processed.csv")

    artefacts = build_smoothed_crispr_for_study(
        raw_crispr_path=str(root / "crispr_raw" / "crispr_net.csv.gz"),
        bacteria_features=bacteria.columns,
        virus_features=virus.columns,
        tax_bac_path=cfg.tax_bac_for_smoothing,
        tax_vir_path=cfg.tax_vir,
        bacterial_ranks=cfg.bacterial_ranks,
        viral_ranks=cfg.viral_ranks,
        bacterial_weights=cfg.bacterial_weights,
        viral_weights=cfg.viral_weights,
        alpha=cfg.crispr_smooth_alpha,
        transpose_after_load=cfg.transpose_raw_crispr_after_load,
    )

    paths: dict[str, Path] = {}
    for key in ("crispr_binary", "crispr_smooth", "K_bac", "K_vir"):
        destination = out_dir / f"{key}.csv.gz"
        write_df(artefacts[key], destination)
        paths[key] = destination

    # Transposed copy; build_xstar_one reads this file first when present.
    vir_bac_path = out_dir / "crispr_smooth_vir_bac.csv.gz"
    write_df(artefacts["crispr_smooth"].T, vir_bac_path)
    paths["crispr_smooth_vir_bac"] = vir_bac_path

    logger.info("smoothed CRISPR written to %s", out_dir)
    return paths


def build_xstar_one(cfg: CapelliniConfig) -> dict[str, Path]:
    """Propagate residual X* and compute Schäfer-Strimmer correlations on Z*.

    The smoothed CRISPR matrix is loaded from ``crispr_smooth_vir_bac.csv.gz``.
    When that file is missing, the transpose of ``crispr_smooth.csv.gz`` is used
    instead. The matrix is oriented against the common ``V`` / ``B`` tables and
    passed to ``build_xstar_from_smoothed_crispr`` with ``cfg.pseudocount``,
    ``cfg.lam`` and ``cfg.n_steps``. Shrinkage correlations are then computed
    on ``[B_star  V_star]``.

    Returns
    -------
    dict[str, Path]
        Paths for ``V_clr``, ``B_clr``, ``V_star``, ``B_star``, ``X_star``,
        ``shrinkage_xstar_full`` and ``shrinkage_xstar_BV`` under
        ``<output_root>/xstar``.
    """
    out_dir = _subdir(cfg, "xstar")
    root = Path(cfg.output_root)
    common_dir = root / "common"
    smooth_dir = root / "crispr_smooth"

    virus = read_table(common_dir / "V_processed.csv")
    bacteria = read_table(common_dir / "B_processed.csv")

    # Prefer the pre-transposed matrix; otherwise transpose crispr_smooth here.
    vir_bac_path = smooth_dir / "crispr_smooth_vir_bac.csv.gz"
    if not vir_bac_path.exists():
        W = read_table(smooth_dir / "crispr_smooth.csv.gz").T
    else:
        W = read_table(vir_bac_path)
    W = orient_W_viruses_by_bacteria(W, virus, bacteria)

    res = build_xstar_from_smoothed_crispr(
        V_df=virus,
        B_df=bacteria,
        W_vh_smooth_df=W,
        pseudocount=cfg.pseudocount,
        lam=cfg.lam,
        n_steps=cfg.n_steps,
    )

    # Shrinkage on Z* = [B_star  V_star]
    v_star = res["V_star"].add_prefix("V__")
    b_star = res["B_star"].add_prefix("B__")
    corr_full, corr_bv, info_star = _shrinkage_block(
        pd.concat([b_star, v_star], axis=1), b_star.columns, v_star.columns
    )

    # Result entries whose file name is simply "<key>.csv.gz".
    res_keys = ("V_clr", "B_clr", "V_star", "B_star", "X_star")
    paths = {key: out_dir / f"{key}.csv.gz" for key in res_keys}
    paths["shrinkage_xstar_full"] = out_dir / "shrinkage_xstar_corr_full.csv.gz"
    paths["shrinkage_xstar_BV"] = out_dir / "shrinkage_xstar_corr_BV.csv.gz"

    for key in res_keys:
        write_df(res[key], paths[key])
    write_df(corr_full, paths["shrinkage_xstar_full"])
    write_df(corr_bv, paths["shrinkage_xstar_BV"])

    logger.info("X* + Z*-shrinkage written to %s (lambda=%s)", out_dir, info_star)
    return paths


def run_network(cfg: CapelliniConfig) -> dict[str, dict]:
    """Run the network sub-stages enabled by the ``cfg.run_*`` flags.

    Sub-stages run in a fixed order, because later stages read files written
    by earlier ones:

    * ``shrinkage`` and ``crispr_smooth`` read ``common/``.
    * ``crispr_smooth`` reads ``crispr_raw/``.
    * ``xstar`` reads ``crispr_smooth/``.

    Returns
    -------
    dict[str, dict]
        Per-stage mapping of output paths, keyed by stage name. Only stages
        that actually ran are included.

    Raises
    ------
    ValueError
        If ``cfg.output_root`` is empty or unset.
    """
    if not cfg.output_root:
        raise ValueError("Network stage requires cfg.output_root (or enhanced_networks_folder)")
    Path(cfg.output_root).mkdir(parents=True, exist_ok=True)

    # (config flag, result key, runner), in pipeline order.
    stages = (
        ("run_common_abundance", "common", build_common_abundance_one),
        ("run_shrinkage_correlations", "shrinkage", build_shrinkage_one),
        ("run_raw_crispr_networks", "crispr_raw", lambda c: {"crispr_net": build_raw_crispr_one(c)}),
        ("run_smooth_crispr", "crispr_smooth", build_smooth_crispr_one),
        ("run_xstar", "xstar", build_xstar_one),
    )

    results: dict[str, dict] = {}
    for flag, key, runner in stages:
        if getattr(cfg, flag):
            results[key] = runner(cfg)

    logger.info("Network stage complete (%d sub-stages)", len(results))
    return results