@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
"-m",
"--model",
"model_path",
type=click.Path(path_type=Path),
required=True,
help="Linker artifact path (same as Linker.dump / Linker.load, with or without .gz).",
)
@click.option(
"--thr-score",
type=float,
default=0.5,
show_default=True,
help="Minimum cluster membership score (same role as server thr_score).",
)
@click.option(
"--use-gpu",
is_flag=True,
help="Move transformer heads to CUDA when available.",
)
@click.option(
"--include-anomaly-metrics",
is_flag=True,
help="Include PCA residual / Mahalanobis anomaly metrics in entity outputs.",
)
@click.option(
"--kb-validation",
is_flag=True,
help="Include kb matching",
)
@click.option(
"-o",
"--output",
"output_path",
type=click.Path(path_type=Path),
default=None,
help="Write the entity report JSON (UTF-8) to this path.",
)
@click.option(
"--dump-mention-anomaly",
"dump_mention_anomaly",
type=click.Path(path_type=Path),
default=None,
help=(
"If set, write one row per extracted mention with is_kb_match and PCA anomaly "
"metrics (residual / Mahalanobis / max-z). Format inferred from extension: "
".parquet, .csv, .jsonl."
),
)
@click.option(
"--max-length",
type=int,
default=MAX_LENGTH,
show_default=True,
help="Tokenizer chunk length.",
)
@click.argument(
"files",
nargs=-1,
required=True,
type=click.Path(exists=True, readable=True, path_type=Path),
)
def main(
model_path: Path,
files: tuple[Path, ...],
thr_score: float,
use_gpu: bool,
include_anomaly_metrics: bool,
kb_validation: bool,
output_path: Path | None,
dump_mention_anomaly: Path | None,
max_length: int,
) -> None:
"""Load a dumped Linker and predict entities for each input.
Inputs are UTF-8 files. If a file parses as JSON, supported shapes are:
\\b
- A single object: {"text": "...", "ground_truth": [ ... optional hits ... ]}
- A list of objects: [{"text": "...", "ground_truth": [...]}, ...]
Each optional ``ground_truth`` hit is typically an object with character offsets
``a``, ``b`` and a class / entity id (e.g. ``entity_id``). ``itext`` in the file
is ignored and rewritten to match the global document index in the output.
Any file that is not valid structured JSON (or does not start with ``{`` / ``[``)
is read as plain text (one document per file).
"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
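    # _flatten_inputs expands each file into one or more documents and returns
    # per-document ground truth (None where a file carried no annotations).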
texts, ground_truth_by_doc = _flatten_inputs(files)
if not texts:
logger.error("No input documents after parsing files")
raise SystemExit(1)
try:
linker = Linker.load(model_path)
except FileNotFoundError:
logger.exception("Model not found (expected .gz next to the given path)")
raise SystemExit(1)
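    # Only request per-mention anomaly rows from predict() when the caller
    # asked for a dump file; they are unused otherwise.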
want_mention_dump = dump_mention_anomaly is not None
try:
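        # Predict with threshold=0.0 so no candidate is dropped early; the
        # user-facing --thr-score cut is applied afterwards via filter_by_score.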
pres = linker.predict(
texts,
max_length=max_length,
threshold=0.0,
use_gpu=use_gpu,
include_mention_anomaly=want_mention_dump,
include_prediction_kb_validation=kb_validation,
)
filtered = pres.filter_by_score(thr_score)
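        # Keep only the public entity fields unless extra diagnostics (anomaly
        # metrics or KB validation) were explicitly requested.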
public_entity_fields = not include_anomaly_metrics and not kb_validation
out = filtered.to_dict(
include_entity_anomaly_metrics=include_anomaly_metrics,
public_entity_fields=public_entity_fields,
)
except Exception:
logger.exception("predict failed")
raise SystemExit(1)
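    # The dump format (.parquet / .csv / .jsonl) is inferred from the file
    # extension inside _write_mention_anomaly.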
if want_mention_dump:
try:
rows = list(pres.debug_mentions) if pres.debug_mentions is not None else []
_write_mention_anomaly(dump_mention_anomaly, rows)
except Exception:
logger.exception("mention anomaly dump failed")
raise SystemExit(1)
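    # Carry any per-document ground truth through to the report so predictions
    # can be compared against the annotations.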
if any(g is not None for g in ground_truth_by_doc):
out["ground_truth"] = ground_truth_by_doc
    payload = json.dumps(_sanitize_for_json(out), ensure_ascii=False)
    if output_path is not None:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(payload + "\n", encoding="utf-8")
    else:
        # Without --output the report would otherwise be computed and then
        # silently discarded, so fall back to printing it on stdout.
        click.echo(payload)