# Source code for capellini.pipeline

"""Top-level pipeline orchestrator."""

from __future__ import annotations

import logging
from typing import Any

from capellini.config import CapelliniConfig
from capellini.stages import (
    dada2 as dada2_stage,
    mmseqs2 as mmseqs2_stage,
    ncbi_mapping as ncbi_stage,
    network as network_stage,
    preflight as preflight_stage,
    procs as procs_stage,
    spacepharer as spacepharer_stage,
)

# Module-level logger named after this module, following the standard
# ``logging.getLogger(__name__)`` convention.
logger = logging.getLogger(__name__)


def _build_gca_target_set(silva_fixed, species_level: bool) -> set:
    """Collect the GCA identifiers targeted by downstream stages.

    Intended to mirror the target-GCA derivation performed by the
    SpacePHARER stage (that stage's code is not visible here).

    Args:
        silva_fixed: Table-like object indexable by column name whose
            columns support ``.dropna()``. Presumably a pandas DataFrame
            with ``GCA_species``, ``GCA_genus`` and ``GCA_family``
            columns -- TODO confirm against the mmseqs2 stage output.
        species_level: When true, combine the species-level column with
            the family-level column; otherwise combine the genus-level
            column with the family-level column.

    Returns:
        The set of non-null values from the selected rank column united
        with the non-null values of ``GCA_family``.
    """
    # The family-level column is always included; only the finer-grained
    # column depends on the requested resolution.
    rank_column = "GCA_species" if species_level else "GCA_genus"
    targets = set(silva_fixed[rank_column].dropna())
    targets.update(silva_fixed["GCA_family"].dropna())
    return targets


class CapelliniPipeline:
    """Run the CAPELLINI pipeline stages in order, sharing inter-stage state.

    Results that later stages depend on are kept in ``self.state``:

    * ``"taxonomy_table"`` -- return value of the NCBI mapping stage.
    * ``"silva_fixed"`` -- return value of the MMseqs2 step, which runs as
      part of the ``"ncbi_mapping"`` stage.
    * ``"gca_target_set"`` -- set derived from ``silva_fixed`` by
      ``_build_gca_target_set``.
    * ``"pa_matrix"`` -- return value of the ProCs stage.
    """

    # Execution order used by ``run_all`` and ``run_from``.
    STAGE_ORDER = [
        "preflight",
        "dada2",
        "ncbi_mapping",
        "spacepharer",
        "procs",
        "network",
    ]

    # Human-readable label for each stage name in ``STAGE_ORDER``.
    STAGE_LABELS = {
        "preflight": "Preflight",
        "dada2": "DADA2",
        "ncbi_mapping": "3-layer NCBI ID Mapping",
        "spacepharer": "SpacePHARER Execution",
        "procs": "Protein Clusters (ProCs) Estimation",
        "network": "Enhanced Networks Estimation",
    }

    def __init__(self, config: CapelliniConfig) -> None:
        """Store the configuration and start with empty inter-stage state.

        Args:
            config: Pipeline configuration handed to every stage.
        """
        self.config = config
        self.state: dict[str, Any] = {}
        # Only install a default logging setup when the host application
        # has not configured the root logger itself.
        if not logging.getLogger().handlers:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            )

    # ── Public API ───────────────────────────────────────────────────────────

    def run_all(self) -> None:
        """Run every stage in STAGE_ORDER."""
        for stage in self.STAGE_ORDER:
            self.run_stage(stage)

    def run_from(self, name: str) -> None:
        """Run from the given stage to the end.

        Args:
            name: Stage name to start from; must be listed in
                ``STAGE_ORDER``.

        Raises:
            ValueError: If ``name`` is not a known stage.
            RuntimeError: If a stage needs state that only an earlier,
                skipped stage would have produced (see ``_require``).
        """
        try:
            idx = self.STAGE_ORDER.index(name)
        except ValueError:
            # ``list.index`` only reports "x is not in list"; raise the same
            # kind of message ``run_stage`` uses and list the valid names.
            raise ValueError(
                f"Unknown stage: {name!r}. "
                f"Valid stages: {', '.join(self.STAGE_ORDER)}"
            ) from None
        for stage in self.STAGE_ORDER[idx:]:
            self.run_stage(stage)

    def run_stage(self, name: str) -> Any:
        """Dispatch a single stage by name.

        Args:
            name: One of the names in ``STAGE_ORDER``.

        Returns:
            Whatever the underlying stage function returns. For
            ``"ncbi_mapping"`` this is the MMseqs2 result (``silva_fixed``);
            for ``"procs"`` it is the value stored as ``"pa_matrix"``.

        Raises:
            ValueError: If ``name`` is not a known stage.
            RuntimeError: If ``"spacepharer"`` or ``"procs"`` is run before
                ``"silva_fixed"`` is available in ``self.state``.
        """
        logger.info("=== Stage: %s ===", name)

        if name == "preflight":
            return preflight_stage.run_preflight(self.config)

        if name == "dada2":
            return dada2_stage.run_dada2(self.config)

        if name == "ncbi_mapping":
            # This stage chains the NCBI mapping and the MMseqs2 step and
            # caches everything the later stages read from ``self.state``.
            tt = ncbi_stage.run_ncbi_mapping(self.config)
            self.state["taxonomy_table"] = tt
            silva_fixed = mmseqs2_stage.run_mmseqs2(self.config, tt)
            self.state["silva_fixed"] = silva_fixed
            self.state["gca_target_set"] = _build_gca_target_set(
                silva_fixed, self.config.species_level
            )
            return silva_fixed

        if name == "spacepharer":
            silva_fixed = self._require("silva_fixed")
            return spacepharer_stage.run_spacepharer(self.config, silva_fixed)

        if name == "procs":
            gca = self.state.get("gca_target_set")
            if gca is None:
                # Rebuild the target set when only ``silva_fixed`` is
                # present in the state, then cache it for later use.
                silva_fixed = self._require("silva_fixed")
                gca = _build_gca_target_set(
                    silva_fixed, self.config.species_level
                )
                self.state["gca_target_set"] = gca
            pa = procs_stage.run_procs(self.config, gca)
            self.state["pa_matrix"] = pa
            return pa

        if name == "network":
            return network_stage.run_network(self.config)

        raise ValueError(f"Unknown stage: {name!r}")

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _require(self, key: str) -> Any:
        """Return ``self.state[key]``, failing loudly when it is absent.

        Args:
            key: Name of the inter-stage state entry needed by the caller.

        Raises:
            RuntimeError: If ``key`` has not been stored in ``self.state``.
        """
        if key not in self.state:
            raise RuntimeError(
                f"Stage requires '{key}' in pipeline state. "
                "Run upstream stages first or call run_all()."
            )
        return self.state[key]