Skip to content

pelinker.cli.fit

FitCliConfig dataclass

Hydra config for python -m pelinker.cli.fit.

Source code in pelinker/cli/fit.py
@dataclass
class FitCliConfig:
    """Hydra config for ``python -m pelinker.cli.fit``."""

    model_type: str = "pubmedbert"
    layers_spec: str = "1"
    kb_path: str = MISSING
    pca_components: int = 100
    umap_dim: int = 8
    min_class_size: int = 20
    # HDBSCAN ``min_cluster_size`` for stage (B); choose the value upstream
    # (e.g. ``run/analysis/clustering_quality.py``).
    min_cluster_size: int = 20
    # Filesystem base path handed to ``Linker.dump`` (the linker appends ``.gz``).
    model_path: str | None = None
    # Directory that receives fit-time reports (``linker_fit.clustering_report.json``).
    report_path: str | None = None
    embeddings_parquet: Any = MISSING
    input_text_table_path: str | None = None
    use_gpu: bool = False
    nlp_model: str = "en_core_web_trf"
    # Stage (A) knobs: text-table I/O buffer rows, encoder batch size (GPU),
    # and an optional cap on the number of read passes.
    input_buffer_rows: int = 1000
    encoder_batch_size: int = 200
    max_input_buffers: int | None = None
    negatives_per_positive: float = 0.0
    negative_label: str = NEGATIVE_LABEL
    negative_seed: int | None = None
    screener_kind: str = "lda"
    """``lda`` or ``svm``; persisted as :attr:`~pelinker.model.Linker.screener`."""
    manifold_oov_enabled: bool = True
    """When false, skip 3D manifold OOV score model (no predict-time gate from that path)."""
    # Stage (B) knobs: parquet batching (``batch_size`` rows per read batch).
    n_embedding_batches: int | None = None  # max read batches per parquet; None = all
    batch_size: int = 1000
    kb_name: str | None = None
    kb_version: str = "0.1.0"
    kb_created_at: str | None = None
    kb_description: str = ""
    kb_entity_count: int | None = None
    # Pipeline discriminator: auto = fit from parquet only if no text table is
    # given; otherwise embed then fit (legacy behaviour).
    # Kept as ``str`` (not ``Literal``): OmegaConf structured configs reject
    # ``Literal`` annotations on fields.
    pipeline: str = "embed_only"
    # Per-parquet backbone/layer overrides (length 1 broadcasts, or same length
    # as ``embeddings_parquet``).  When omitted, the ``model_type`` /
    # ``layers_spec`` scalars apply unless the parquet stem matches
    # ``..._<model>_<layers>`` (see ``_parse_embedding_parquet_stem``).
    model_types: list[str] | None = None
    layers_specs: list[str] | None = None

    def __post_init__(self) -> None:
        """Reject invalid field values early, at config construction time."""
        if self.pipeline not in _PIPELINE_VALUES:
            msg = (
                f"pipeline must be one of {sorted(_PIPELINE_VALUES)}, "
                f"got {self.pipeline!r}"
            )
            raise ValueError(msg)
        allowed_screeners = ("lda", "svm")
        if self.screener_kind not in allowed_screeners:
            raise ValueError(
                f"screener_kind must be 'lda' or 'svm', got {self.screener_kind!r}"
            )
        if self.min_cluster_size < 2:
            raise ValueError("min_cluster_size must be >= 2")

manifold_oov_enabled = True class-attribute instance-attribute

When false, skip 3D manifold OOV score model (no predict-time gate from that path).

screener_kind = 'lda' class-attribute instance-attribute

``lda`` or ``svm``; persisted as :attr:`~pelinker.model.Linker.screener`.

fit(cfg)

Run embedding (optional), fit a Linker from parquet(s) (optional), and write outputs.

Paths (no implicit fallbacks — missing required paths raise):

  • embeddings_parquet: output path(s) for embed_only / both stage (A), or input parquet(s) for fit_only / both stage (B).
  • report_path: directory; fit stages write linker_fit.clustering_report.json there (see :func:`pelinker.reporting.linker_fit_clustering_report_path`).
  • model_path: filesystem path passed to Linker.dump for fit stages.

Pipelines:

  • pipeline=auto: embed then fit if input_text_table_path is set; else fit from parquet.
  • pipeline=embed_only: write parquet(s) only (model_path / report_path not used).
  • pipeline=fit_only: fit from existing parquet(s); requires model_path and report_path.
  • pipeline=both: text table + embed then fit; requires model_path and report_path.

Multiple embeddings_parquet values fuse in list order (inner join on pmid/entity/mention). Set model_types / layers_specs (or scalars) so embedding_metadata.sources matches; or infer model_type / layers_spec from each filename stem when lists are omitted.

Source code in pelinker/cli/fit.py
def fit(cfg: FitCliConfig) -> None:
    """
    Run embedding (optional), fit a ``Linker`` from parquet(s) (optional), and write outputs.

    Paths (no implicit fallbacks — missing required paths raise):

    - ``embeddings_parquet``: output path(s) for ``embed_only`` / ``both`` stage (A), or input
      parquet(s) for ``fit_only`` / ``both`` stage (B).
    - ``report_path``: directory; fit stages write ``linker_fit.clustering_report.json`` there
      (see :func:`pelinker.reporting.linker_fit_clustering_report_path`).
    - ``model_path``: filesystem path passed to ``Linker.dump`` for fit stages.

    Pipelines:

    - ``pipeline=auto``: embed then fit if ``input_text_table_path`` is set; else fit from parquet.
    - ``pipeline=embed_only``: write parquet(s) only (``model_path`` / ``report_path`` not used).
    - ``pipeline=fit_only``: fit from existing parquet(s); requires ``model_path`` and ``report_path``.
    - ``pipeline=both``: text table + embed then fit; requires ``model_path`` and ``report_path``.

    Multiple ``embeddings_parquet`` values fuse in list order (inner join on pmid/entity/mention).
    Set ``model_types`` / ``layers_specs`` (or scalars) so ``embedding_metadata.sources`` matches;
    or infer ``model_type`` / ``layers_spec`` from each filename stem when lists are omitted.
    """
    # CLI entry point: configure root logging with timestamps at INFO level.
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # The KB CSV is required for every pipeline: it supplies the
    # entity_id -> label mapping consumed by the Linker.
    kb_path = expand_config_path(cfg.kb_path)
    if kb_path is None:
        raise ValueError("kb_path must be provided")
    logger.info("Using KB: %s", kb_path)

    df0 = pd.read_csv(kb_path)

    logger.info("Loaded %s properties from KB", len(df0))

    # Fail fast on a malformed KB before any expensive embedding/fitting work.
    if "entity_id" not in df0.columns:
        raise ValueError(
            "KB CSV must contain an 'entity_id' column "
            "(see run/embed_kb_corpus --kb-csv-path)."
        )
    if "label" not in df0.columns:
        raise ValueError("KB CSV must contain a 'label' column.")

    # entity_id -> label, skipping rows whose label is NaN.  If an entity_id
    # repeats, the last row wins (plain dict-comprehension semantics).
    labels_map: dict[str, str] = {
        str(eid): str(lbl)
        for eid, lbl in zip(df0["entity_id"], df0["label"])
        if pd.notna(lbl)
    }

    # Only used for the log line below; the Linker itself consumes labels_map.
    kb_labels = set(df0["label"].dropna().unique())

    logger.info("Extracted %s unique entity labels from KB", len(kb_labels))

    transform_config = TransformConfig(
        pca_components=cfg.pca_components,
        umap_components=cfg.umap_dim,
    )

    # Expand/normalize user-supplied paths (None stays None for optional ones).
    input_text_table_path = expand_config_path(cfg.input_text_table_path)
    model_path = expand_config_path(cfg.model_path)
    report_path_resolved = expand_config_path(cfg.report_path)

    # embeddings_parquet may be a scalar or a list; coerce to list[str] first.
    path_strs = _coerce_str_list(cfg.embeddings_parquet)
    if not path_strs:
        raise ValueError("embeddings_parquet must be one or more paths")

    embed_paths: list[Path] = []
    for s in path_strs:
        p = expand_config_path(s)
        if p is None:
            raise ValueError(f"Invalid embeddings path: {s!r}")
        embed_paths.append(p)

    # Resolve per-parquet backbone/layer metadata; list options broadcast
    # against embed_paths (scalars apply when lists are omitted).
    mts = _coerce_optional_str_list(cfg.model_types)
    lss = _coerce_optional_str_list(cfg.layers_specs)
    embedding_metadata = _embedding_metadata(
        embed_paths, cfg.model_type, cfg.layers_spec, mts, lss
    )

    # Resolve the "auto" discriminator to a concrete pipeline:
    # a text table means "embed then fit", otherwise fit from existing parquet.
    pipeline = cfg.pipeline
    if pipeline == "auto":
        effective: FitPipeline = "fit_only" if input_text_table_path is None else "both"
    else:
        effective = cast(FitPipeline, pipeline)

    # Validate the pipeline/argument combination before doing any work.
    if effective == "fit_only" and input_text_table_path is not None:
        raise ValueError(
            "pipeline=fit_only (or auto with no text table): omit input_text_table_path."
        )
    if effective in ("both", "embed_only") and input_text_table_path is None:
        raise ValueError(
            f"pipeline={effective} requires input_text_table_path for stage (A)."
        )

    # Fitting pipelines must know where to dump the model and the report.
    if effective in ("fit_only", "both"):
        if model_path is None:
            raise ValueError(
                "model_path is required for pipeline fit_only, both, or auto when fitting"
            )
        if report_path_resolved is None:
            raise ValueError(
                "report_path is required for pipeline fit_only, both, or auto when fitting"
            )

    # fit_only reads the parquet(s); they must already exist on disk.
    if effective == "fit_only":
        missing = [p for p in embed_paths if not p.is_file()]
        if missing:
            raise FileNotFoundError(
                f"Embedding parquet(s) not found for fit_only: {missing}"
            )

    # Embedding pipelines write the parquet(s); refuse to clobber existing files.
    if effective == "both":
        _abort_if_outputs_exist(
            embed_paths,
            context="pipeline=both: embeddings_parquet target(s)",
        )
    elif effective == "embed_only":
        _abort_if_outputs_exist(
            embed_paths,
            context="pipeline=embed_only",
        )

    # Stage (A): embed the text table into the target parquet(s).
    if effective in ("both", "embed_only"):
        logger.info(
            "Stage (A): embed_kb_corpus → %s",
            embed_paths if len(embed_paths) > 1 else embed_paths[0],
        )
        training = EmbeddingTrainingConfig(
            input_text_table_path=input_text_table_path,
            kb_csv_path=kb_path,
            use_gpu=cfg.use_gpu,
            input_buffer_rows=cfg.input_buffer_rows,
            encoder_batch_size=cfg.encoder_batch_size,
            nlp_model=cfg.nlp_model,
            max_input_buffers=cfg.max_input_buffers,
            negatives_per_positive=cfg.negatives_per_positive,
            negative_label=cfg.negative_label,
            negative_seed=cfg.negative_seed,
        )
        # embed_kb_corpus has distinct singular/plural output keyword arguments.
        if len(embed_paths) == 1:
            embed_kb_corpus(
                metadata=embedding_metadata,
                training=training,
                output_parquet_path=embed_paths[0],
            )
        else:
            embed_kb_corpus(
                metadata=embedding_metadata,
                training=training,
                output_parquet_paths=tuple(embed_paths),
            )

    # embed_only stops here; no linker is fitted or saved.
    if effective == "embed_only":
        logger.info("Embed-only pipeline finished; not fitting or saving a linker.")
        return

    linker_fit_cfg = LinkerFitConfig(
        min_class_size=cfg.min_class_size,
        batch_size=cfg.batch_size,
        n_embedding_batches=cfg.n_embedding_batches,
        negative_screener=NegativeScreenerConfig(
            kind=cfg.screener_kind,
            negative_label=cfg.negative_label,
        ),
        manifold_oov_screener=ManifoldOovScreenerConfig(
            enabled=cfg.manifold_oov_enabled,
        ),
    )

    # KB provenance metadata: created_at defaults to today, display name to the
    # CSV filename stem when kb_name is unset/blank.
    kb_created = (
        date.fromisoformat(cfg.kb_created_at) if cfg.kb_created_at else date.today()
    )
    kb_display_name = (cfg.kb_name or "").strip() or kb_path.stem
    kb_config = KBConfig(
        name=kb_display_name,
        version=cfg.kb_version,
        created_at=kb_created,
        description=cfg.kb_description,
        entity_count=cfg.kb_entity_count,
    )

    linker = Linker(
        labels_map=labels_map,
        transform_config=transform_config,
        embedding_metadata=embedding_metadata,
    )

    # Stage (B): fit from the (pre-existing or freshly written) parquet(s).
    logger.info("Stage (B): Linker.fit from %s", embed_paths)

    linker.fit(
        embeddings=embed_paths if len(embed_paths) > 1 else embed_paths[0],
        transform_config=transform_config,
        min_cluster_size=cfg.min_cluster_size,
        fit_config=linker_fit_cfg,
        embedding_training=None,
        kb_config=kb_config,
    )

    logger.info("Fitted Linker model with %s entities", len(linker.vocabulary))
    logger.info(
        "Number of clusters: %s",
        len(set(linker.cluster_assignments.values())),
    )

    # Redundant with the fit-path checks above, but re-asserted here (and it
    # narrows the Optional types for static checkers).
    if model_path is None or report_path_resolved is None:
        raise ValueError("model_path and report_path must be set when fitting")

    # Write the fit-time clustering report JSON into the report directory.
    report_path_resolved.mkdir(parents=True, exist_ok=True)
    fit_report = linker.take_fit_clustering_report()
    if fit_report is None:
        raise RuntimeError("Linker.fit produced no clustering report to serialize")
    report_json = linker_fit_clustering_report_path(report_path_resolved)
    write_clustering_report_json(report_json, fit_report)
    logger.info("Wrote clustering report to %s", report_json)

    # Finally persist the fitted model itself.
    logger.info("Saving model to %s", model_path)
    linker.dump(model_path)

    logger.info("Model saved successfully!")