def _load_kb_labels(kb_path: Path) -> dict[str, str]:
    """Read the KB CSV and return an ``entity_id -> label`` map (NaN labels skipped).

    Raises ``ValueError`` when the required ``entity_id`` or ``label`` column
    is missing from the CSV.
    """
    df0 = pd.read_csv(kb_path)
    logger.info("Loaded %s properties from KB", len(df0))
    if "entity_id" not in df0.columns:
        raise ValueError(
            "KB CSV must contain an 'entity_id' column "
            "(see run/embed_kb_corpus --kb-csv-path)."
        )
    if "label" not in df0.columns:
        raise ValueError("KB CSV must contain a 'label' column.")
    labels_map: dict[str, str] = {
        str(eid): str(lbl)
        for eid, lbl in zip(df0["entity_id"], df0["label"])
        if pd.notna(lbl)
    }
    # Unique-label count is logged for operator visibility only; the rest of
    # the pipeline consumes labels_map.
    kb_labels = set(df0["label"].dropna().unique())
    logger.info("Extracted %s unique entity labels from KB", len(kb_labels))
    return labels_map


def _resolve_embed_paths(embeddings_parquet) -> list[Path]:
    """Expand the configured embeddings-parquet value(s) into ``Path`` objects.

    Raises ``ValueError`` when no paths are configured or a path fails to expand.
    """
    path_strs = _coerce_str_list(embeddings_parquet)
    if not path_strs:
        raise ValueError("embeddings_parquet must be one or more paths")
    embed_paths: list[Path] = []
    for s in path_strs:
        p = expand_config_path(s)
        if p is None:
            raise ValueError(f"Invalid embeddings path: {s!r}")
        embed_paths.append(p)
    return embed_paths


def _validate_pipeline_paths(
    effective: FitPipeline,
    input_text_table_path: "Path | None",
    model_path: "Path | None",
    report_path_resolved: "Path | None",
    embed_paths: list[Path],
) -> None:
    """Enforce the path preconditions of the resolved pipeline (no fallbacks).

    Checks run in a fixed order so callers see deterministic error messages:
    text-table presence/absence, then model/report paths for fitting stages,
    then parquet existence (``fit_only``) or non-existence (``both`` /
    ``embed_only`` output targets).
    """
    if effective == "fit_only" and input_text_table_path is not None:
        raise ValueError(
            "pipeline=fit_only (or auto with no text table): omit input_text_table_path."
        )
    if effective in ("both", "embed_only") and input_text_table_path is None:
        raise ValueError(
            f"pipeline={effective} requires input_text_table_path for stage (A)."
        )
    if effective in ("fit_only", "both"):
        if model_path is None:
            raise ValueError(
                "model_path is required for pipeline fit_only, both, or auto when fitting"
            )
        if report_path_resolved is None:
            raise ValueError(
                "report_path is required for pipeline fit_only, both, or auto when fitting"
            )
    if effective == "fit_only":
        # fit_only consumes the parquet(s): they must already exist.
        missing = [p for p in embed_paths if not p.is_file()]
        if missing:
            raise FileNotFoundError(
                f"Embedding parquet(s) not found for fit_only: {missing}"
            )
    if effective == "both":
        # both/embed_only produce the parquet(s): refuse to clobber existing files.
        _abort_if_outputs_exist(
            embed_paths,
            context="pipeline=both: embeddings_parquet target(s)",
        )
    elif effective == "embed_only":
        _abort_if_outputs_exist(
            embed_paths,
            context="pipeline=embed_only",
        )


def _run_embed_stage(
    cfg: FitCliConfig,
    input_text_table_path: "Path | None",
    kb_path: Path,
    embed_paths: list[Path],
    embedding_metadata,
) -> None:
    """Stage (A): embed the input text table into the target parquet file(s)."""
    logger.info(
        "Stage (A): embed_kb_corpus → %s",
        embed_paths if len(embed_paths) > 1 else embed_paths[0],
    )
    training = EmbeddingTrainingConfig(
        input_text_table_path=input_text_table_path,
        kb_csv_path=kb_path,
        use_gpu=cfg.use_gpu,
        input_buffer_rows=cfg.input_buffer_rows,
        encoder_batch_size=cfg.encoder_batch_size,
        nlp_model=cfg.nlp_model,
        max_input_buffers=cfg.max_input_buffers,
        negatives_per_positive=cfg.negatives_per_positive,
        negative_label=cfg.negative_label,
        negative_seed=cfg.negative_seed,
    )
    # embed_kb_corpus distinguishes single vs. multiple outputs by keyword.
    if len(embed_paths) == 1:
        embed_kb_corpus(
            metadata=embedding_metadata,
            training=training,
            output_parquet_path=embed_paths[0],
        )
    else:
        embed_kb_corpus(
            metadata=embedding_metadata,
            training=training,
            output_parquet_paths=tuple(embed_paths),
        )


def _run_fit_stage(
    cfg: FitCliConfig,
    kb_path: Path,
    labels_map: dict[str, str],
    transform_config: TransformConfig,
    embedding_metadata,
    embed_paths: list[Path],
    model_path: Path,
    report_path_resolved: Path,
) -> None:
    """Stage (B): fit a ``Linker`` from parquet(s), then write report and model.

    Raises ``RuntimeError`` if the fitted linker yields no clustering report.
    """
    linker_fit_cfg = LinkerFitConfig(
        min_class_size=cfg.min_class_size,
        batch_size=cfg.batch_size,
        n_embedding_batches=cfg.n_embedding_batches,
        negative_screener=NegativeScreenerConfig(
            kind=cfg.screener_kind,
            negative_label=cfg.negative_label,
        ),
        manifold_oov_screener=ManifoldOovScreenerConfig(
            enabled=cfg.manifold_oov_enabled,
        ),
    )
    kb_created = (
        date.fromisoformat(cfg.kb_created_at) if cfg.kb_created_at else date.today()
    )
    # Fall back to the KB CSV filename stem when no display name is configured.
    kb_display_name = (cfg.kb_name or "").strip() or kb_path.stem
    kb_config = KBConfig(
        name=kb_display_name,
        version=cfg.kb_version,
        created_at=kb_created,
        description=cfg.kb_description,
        entity_count=cfg.kb_entity_count,
    )
    linker = Linker(
        labels_map=labels_map,
        transform_config=transform_config,
        embedding_metadata=embedding_metadata,
    )
    logger.info("Stage (B): Linker.fit from %s", embed_paths)
    # Linker.fit accepts a single path or a list of paths to fuse.
    linker.fit(
        embeddings=embed_paths if len(embed_paths) > 1 else embed_paths[0],
        transform_config=transform_config,
        min_cluster_size=cfg.min_cluster_size,
        fit_config=linker_fit_cfg,
        embedding_training=None,
        kb_config=kb_config,
    )
    logger.info("Fitted Linker model with %s entities", len(linker.vocabulary))
    logger.info(
        "Number of clusters: %s",
        len(set(linker.cluster_assignments.values())),
    )
    report_path_resolved.mkdir(parents=True, exist_ok=True)
    fit_report = linker.take_fit_clustering_report()
    if fit_report is None:
        raise RuntimeError("Linker.fit produced no clustering report to serialize")
    report_json = linker_fit_clustering_report_path(report_path_resolved)
    write_clustering_report_json(report_json, fit_report)
    logger.info("Wrote clustering report to %s", report_json)
    logger.info("Saving model to %s", model_path)
    linker.dump(model_path)
    logger.info("Model saved successfully!")


def fit(cfg: FitCliConfig) -> None:
    """
    Run embedding (optional), fit a ``Linker`` from parquet(s) (optional), and write outputs.
    Paths (no implicit fallbacks — missing required paths raise):
    - ``embeddings_parquet``: output path(s) for ``embed_only`` / ``both`` stage (A), or input
      parquet(s) for ``fit_only`` / ``both`` stage (B).
    - ``report_path``: directory; fit stages write ``linker_fit.clustering_report.json`` there
      (see :func:`pelinker.reporting.linker_fit_clustering_report_path`).
    - ``model_path``: filesystem path passed to ``Linker.dump`` for fit stages.
    Pipelines:
    - ``pipeline=auto``: embed then fit if ``input_text_table_path`` is set; else fit from parquet.
    - ``pipeline=embed_only``: write parquet(s) only (``model_path`` / ``report_path`` not used).
    - ``pipeline=fit_only``: fit from existing parquet(s); requires ``model_path`` and ``report_path``.
    - ``pipeline=both``: text table + embed then fit; requires ``model_path`` and ``report_path``.
    Multiple ``embeddings_parquet`` values fuse in list order (inner join on pmid/entity/mention).
    Set ``model_types`` / ``layers_specs`` (or scalars) so ``embedding_metadata.sources`` matches;
    or infer ``model_type`` / ``layers_spec`` from each filename stem when lists are omitted.
    """
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    kb_path = expand_config_path(cfg.kb_path)
    if kb_path is None:
        raise ValueError("kb_path must be provided")
    logger.info("Using KB: %s", kb_path)
    labels_map = _load_kb_labels(kb_path)
    transform_config = TransformConfig(
        pca_components=cfg.pca_components,
        umap_components=cfg.umap_dim,
    )
    input_text_table_path = expand_config_path(cfg.input_text_table_path)
    model_path = expand_config_path(cfg.model_path)
    report_path_resolved = expand_config_path(cfg.report_path)
    embed_paths = _resolve_embed_paths(cfg.embeddings_parquet)
    mts = _coerce_optional_str_list(cfg.model_types)
    lss = _coerce_optional_str_list(cfg.layers_specs)
    embedding_metadata = _embedding_metadata(
        embed_paths, cfg.model_type, cfg.layers_spec, mts, lss
    )
    # Resolve pipeline=auto by presence of the text table; explicit pipelines
    # pass through (cast keeps the FitPipeline literal type for checkers).
    if cfg.pipeline == "auto":
        effective: FitPipeline = (
            "fit_only" if input_text_table_path is None else "both"
        )
    else:
        effective = cast(FitPipeline, cfg.pipeline)
    _validate_pipeline_paths(
        effective, input_text_table_path, model_path, report_path_resolved, embed_paths
    )
    if effective in ("both", "embed_only"):
        _run_embed_stage(
            cfg, input_text_table_path, kb_path, embed_paths, embedding_metadata
        )
    if effective == "embed_only":
        logger.info("Embed-only pipeline finished; not fitting or saving a linker.")
        return
    # Fail fast (and narrow Optional for type-checkers) BEFORE the expensive
    # fit; previously this guard only ran after Linker.fit had completed.
    if model_path is None or report_path_resolved is None:
        raise ValueError("model_path and report_path must be set when fitting")
    _run_fit_stage(
        cfg,
        kb_path,
        labels_map,
        transform_config,
        embedding_metadata,
        embed_paths,
        model_path,
        report_path_resolved,
    )