Skip to content

ontocast.tool.vector_store.patch_retriever

Retrieves multi-ontology context patches from vector search.

OntologyPatchRetriever

Bases: Tool

Combines vector retrieval into one composite ontology graph.

Source code in ontocast/tool/vector_store/patch_retriever.py
class OntologyPatchRetriever(Tool):
    """Combines vector retrieval into one composite ontology graph.

    Vector hits from ``vector_store`` are filtered, merged and optionally
    reranked according to ``patch``. When ``sparql_tool`` is set and expansion
    is requested, the surviving hits seed one induced-subgraph request.
    """

    # Provides the vector search, stored vectors and the default ``top_k``.
    vector_store: QdrantVectorStore = Field(exclude=True)
    # Optional subgraph expander; when ``None`` an empty graph is returned.
    sparql_tool: Any | None = Field(default=None, exclude=True)
    # Score-ratio, MMR and size-cap settings applied to the merged hits.
    patch: PatchRetrievalConfig = Field(
        default_factory=PatchRetrievalConfig,
        exclude=True,
    )

    def _effective_top_k(self, top_k: int | None) -> int:
        """Return ``top_k`` if given, else the vector store's configured ``top_k``."""
        if top_k is not None:
            return top_k
        return self.vector_store.config.top_k

    def retrieve(
        self,
        query: str,
        top_k: int | None = None,
        expand_sparql: bool = True,
        subgraph_depth: int = 1,
        max_total_triples: int = 300,
        estimated_triples_per_query: int = 24,
    ) -> tuple[RDFGraph, list[str]]:
        """Retrieve top-k hits for one query and optional induced subgraph; returns source ontology IRIs.

        Blocking wrapper that runs :meth:`aretrieve` in a fresh event loop.
        All arguments are forwarded unchanged.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running here, so starting one below is safe.
            pass
        else:
            raise RuntimeError(
                "retrieve() cannot be called from async code; use await aretrieve()"
            )
        # Run outside the ``except`` block so errors raised by the coroutine are
        # not implicitly chained to the "no running event loop" probe error.
        return asyncio.run(
            self.aretrieve(
                query=query,
                top_k=top_k,
                expand_sparql=expand_sparql,
                subgraph_depth=subgraph_depth,
                max_total_triples=max_total_triples,
                estimated_triples_per_query=estimated_triples_per_query,
            )
        )

    def retrieve_ensemble(
        self,
        queries: list[str],
        top_k: int | None = None,
        expand_sparql: bool = True,
        subgraph_depth: int = 1,
        max_total_triples: int = 300,
        estimated_triples_per_query: int = 24,
    ) -> tuple[RDFGraph, list[str]]:
        """Return one induced graph and source IRIs for the union of vector hits over ``queries``.

        Blocking wrapper that runs :meth:`aretrieve_ensemble` in a fresh event
        loop. All arguments are forwarded unchanged.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running here, so starting one below is safe.
            pass
        else:
            raise RuntimeError(
                "retrieve_ensemble() is not allowed inside async code; use aretrieve_ensemble()"
            )
        # Run outside the ``except`` block so errors raised by the coroutine are
        # not implicitly chained to the "no running event loop" probe error.
        return asyncio.run(
            self.aretrieve_ensemble(
                queries=queries,
                top_k=top_k,
                expand_sparql=expand_sparql,
                subgraph_depth=subgraph_depth,
                max_total_triples=max_total_triples,
                estimated_triples_per_query=estimated_triples_per_query,
            )
        )

    async def aretrieve(
        self,
        query: str,
        top_k: int | None = None,
        expand_sparql: bool = True,
        subgraph_depth: int = 1,
        max_total_triples: int = 300,
        estimated_triples_per_query: int = 24,
    ) -> tuple[RDFGraph, list[str]]:
        """Async single-query variant of :meth:`aretrieve_ensemble`."""
        return await self.aretrieve_ensemble(
            queries=[query],
            top_k=top_k,
            expand_sparql=expand_sparql,
            subgraph_depth=subgraph_depth,
            max_total_triples=max_total_triples,
            estimated_triples_per_query=estimated_triples_per_query,
        )

    async def aretrieve_ensemble(
        self,
        queries: list[str],
        top_k: int | None = None,
        expand_sparql: bool = True,
        subgraph_depth: int = 1,
        max_total_triples: int = 300,
        estimated_triples_per_query: int = 24,
    ) -> tuple[RDFGraph, list[str]]:
        """Vector search over all ``queries`` once, score-filter, dedupe, single subgraph expansion.

        Hits are filtered per query and per channel relative to each channel's best
        score (see ``PatchRetrievalConfig`` per-query ratio fields for core,
        neighborhood, and BM25), then merged by rank fusion so channels with
        different score distributions all contribute. Optional per-channel
        min-best filters and ``min_merged_max_score`` reject weak or irrelevant
        candidates.

        Returns the merged RDF graph (possibly disconnected across ontologies) and sorted
        distinct ontology IRIs that contributed vector hits.

        An empty ``queries`` list, or an empty set of surviving hits when
        expansion is requested, yields an empty graph and an empty IRI list.
        When ``expand_sparql`` is false or no ``sparql_tool`` is configured,
        the graph is empty but the source IRIs are still returned.
        """
        if not queries:
            return RDFGraph(), []
        eff_top_k = self._effective_top_k(top_k)
        hits_by_query = await self.vector_store.asearch_patch_hits_many(
            queries=queries,
            top_k=eff_top_k,
        )
        qc = self.vector_store.config
        pc = self.patch
        merged = _filter_and_merge_patch_hits(
            hits_by_query,
            qdrant_config=qc,
            per_query_core_score_ratio=pc.per_query_core_score_ratio,
            per_query_neighborhood_score_ratio=pc.per_query_neighborhood_score_ratio,
            per_query_bm25_score_ratio=pc.per_query_bm25_score_ratio,
            min_core_query_best_score=pc.min_core_query_best_score,
            min_neighborhood_query_best_score=pc.min_neighborhood_query_best_score,
            min_bm25_query_best_score=pc.min_bm25_query_best_score,
            min_merged_max_score=pc.min_merged_max_score,
        )
        if merged and pc.merged_score_ratio > 0.0:
            # The first merged hit's score is the reference for the relative floor.
            # NOTE(review): assumes the merge helper returns hits best-first — confirm.
            merged_top = float(merged[0].score or 0.0)
            merged_floor = merged_top * pc.merged_score_ratio
            merged = [
                atom for atom in merged if float(atom.score or 0.0) >= merged_floor
            ]
        if merged and pc.mmr_lambda < 1.0:
            # Reranking needs the stored vectors of the surviving hits.
            vectors = await self.vector_store.afetch_vectors(
                [atom.atom_id for atom in merged]
            )
            merged = _mmr_rerank(
                merged,
                vectors,
                mmr_lambda=pc.mmr_lambda,
                max_atoms=pc.max_atoms,
                core_weight=qc.fusion_core_weight,
                neighborhood_weight=qc.fusion_neighborhood_weight,
            )
        elif pc.max_atoms > 0:
            # Without reranking, only the size cap is applied.
            merged = merged[: pc.max_atoms]
        source_iris = _source_iris_from_atoms(merged)

        if not expand_sparql or self.sparql_tool is None:
            return RDFGraph(), source_iris

        if not merged:
            return RDFGraph(), []

        entity_uris, entity_relevance = _ranked_entity_weights(merged)
        ontology_iris = sorted(
            {atom.ontology_iri for atom in merged if atom.ontology_iri}
        )
        # Per-ontology sets of versions / hashes seen among the hits; passed to
        # the subgraph request as filters (``None`` when nothing was collected).
        ontology_version_filters: dict[str, set[str]] = {}
        ontology_hash_filters: dict[str, set[str]] = {}
        for atom in merged:
            if atom.ontology_iri and atom.ontology_version:
                ontology_version_filters.setdefault(atom.ontology_iri, set()).add(
                    str(atom.ontology_version)
                )
            if atom.ontology_iri and atom.ontology_hash:
                ontology_hash_filters.setdefault(atom.ontology_iri, set()).add(
                    atom.ontology_hash
                )

        graph = await self.sparql_tool.aget_induced_subgraph(
            entity_uris=entity_uris,
            entity_relevance=entity_relevance,
            ontology_iris=ontology_iris,
            depth=subgraph_depth,
            max_total_triples=max_total_triples,
            estimated_triples_per_query=estimated_triples_per_query,
            ontology_version_filters=ontology_version_filters or None,
            ontology_hash_filters=ontology_hash_filters or None,
        )
        _bind_common_vocab_prefixes(graph)
        return graph, source_iris

aretrieve(query, top_k=None, expand_sparql=True, subgraph_depth=1, max_total_triples=300, estimated_triples_per_query=24) async

Async single-query variant of aretrieve_ensemble.

Source code in ontocast/tool/vector_store/patch_retriever.py
async def aretrieve(
    self,
    query: str,
    top_k: int | None = None,
    expand_sparql: bool = True,
    subgraph_depth: int = 1,
    max_total_triples: int = 300,
    estimated_triples_per_query: int = 24,
) -> tuple[RDFGraph, list[str]]:
    """Retrieve for a single query by delegating to :meth:`aretrieve_ensemble`.

    ``query`` is wrapped in a one-element list; every other argument is
    forwarded unchanged and the delegate's result is returned as-is.
    """
    ensemble_kwargs = {
        "queries": [query],
        "top_k": top_k,
        "expand_sparql": expand_sparql,
        "subgraph_depth": subgraph_depth,
        "max_total_triples": max_total_triples,
        "estimated_triples_per_query": estimated_triples_per_query,
    }
    result = await self.aretrieve_ensemble(**ensemble_kwargs)
    return result

aretrieve_ensemble(queries, top_k=None, expand_sparql=True, subgraph_depth=1, max_total_triples=300, estimated_triples_per_query=24) async

Runs vector search over all queries once, then score-filters and dedupes the hits and performs a single subgraph expansion.

Hits are filtered per query and per channel relative to each channel's best score (see PatchRetrievalConfig per-query ratio fields for core, neighborhood, and BM25), then merged by rank fusion so channels with different score distributions all contribute. Optional per-channel min-best filters and min_merged_max_score reject weak or irrelevant candidates.

Returns the merged RDF graph (possibly disconnected across ontologies) and sorted distinct ontology IRIs that contributed vector hits.

Source code in ontocast/tool/vector_store/patch_retriever.py
async def aretrieve_ensemble(
    self,
    queries: list[str],
    top_k: int | None = None,
    expand_sparql: bool = True,
    subgraph_depth: int = 1,
    max_total_triples: int = 300,
    estimated_triples_per_query: int = 24,
) -> tuple[RDFGraph, list[str]]:
    """Search once over all ``queries``, prune the hits, and expand a single subgraph.

    Per-query, per-channel hits are filtered relative to each channel's best
    score (the ``PatchRetrievalConfig`` per-query ratio fields for core,
    neighborhood, and BM25) and merged by rank fusion, so channels with
    different score distributions all contribute. Optional per-channel
    min-best filters and ``min_merged_max_score`` drop weak or irrelevant
    candidates.

    The result is the merged RDF graph (possibly disconnected across
    ontologies) together with the sorted distinct ontology IRIs that
    contributed vector hits.
    """
    if not queries:
        return RDFGraph(), []

    hits_by_query = await self.vector_store.asearch_patch_hits_many(
        queries=queries,
        top_k=self._effective_top_k(top_k),
    )
    store_cfg = self.vector_store.config
    patch_cfg = self.patch
    merged = _filter_and_merge_patch_hits(
        hits_by_query,
        qdrant_config=store_cfg,
        per_query_core_score_ratio=patch_cfg.per_query_core_score_ratio,
        per_query_neighborhood_score_ratio=patch_cfg.per_query_neighborhood_score_ratio,
        per_query_bm25_score_ratio=patch_cfg.per_query_bm25_score_ratio,
        min_core_query_best_score=patch_cfg.min_core_query_best_score,
        min_neighborhood_query_best_score=patch_cfg.min_neighborhood_query_best_score,
        min_bm25_query_best_score=patch_cfg.min_bm25_query_best_score,
        min_merged_max_score=patch_cfg.min_merged_max_score,
    )

    # Relative floor: keep hits scoring at least ratio * (first hit's score).
    if merged and patch_cfg.merged_score_ratio > 0.0:
        floor = float(merged[0].score or 0.0) * patch_cfg.merged_score_ratio
        merged = [hit for hit in merged if float(hit.score or 0.0) >= floor]

    if merged and patch_cfg.mmr_lambda < 1.0:
        atom_ids = [hit.atom_id for hit in merged]
        vectors = await self.vector_store.afetch_vectors(atom_ids)
        merged = _mmr_rerank(
            merged,
            vectors,
            mmr_lambda=patch_cfg.mmr_lambda,
            max_atoms=patch_cfg.max_atoms,
            core_weight=store_cfg.fusion_core_weight,
            neighborhood_weight=store_cfg.fusion_neighborhood_weight,
        )
    elif patch_cfg.max_atoms > 0:
        # No reranking: only cap the number of hits.
        merged = merged[: patch_cfg.max_atoms]

    source_iris = _source_iris_from_atoms(merged)

    # Without expansion (or without a SPARQL tool) only the IRIs are reported.
    if self.sparql_tool is None or not expand_sparql:
        return RDFGraph(), source_iris
    if not merged:
        return RDFGraph(), []

    entity_uris, entity_relevance = _ranked_entity_weights(merged)

    # One pass collects the contributing ontology IRIs and, per IRI, the
    # versions and hashes seen among the hits.
    seen_iris = set()
    version_filters: dict[str, set[str]] = {}
    hash_filters: dict[str, set[str]] = {}
    for hit in merged:
        iri = hit.ontology_iri
        if not iri:
            continue
        seen_iris.add(iri)
        if hit.ontology_version:
            version_filters.setdefault(iri, set()).add(str(hit.ontology_version))
        if hit.ontology_hash:
            hash_filters.setdefault(iri, set()).add(hit.ontology_hash)
    ontology_iris = sorted(seen_iris)

    graph = await self.sparql_tool.aget_induced_subgraph(
        entity_uris=entity_uris,
        entity_relevance=entity_relevance,
        ontology_iris=ontology_iris,
        depth=subgraph_depth,
        max_total_triples=max_total_triples,
        estimated_triples_per_query=estimated_triples_per_query,
        ontology_version_filters=version_filters or None,
        ontology_hash_filters=hash_filters or None,
    )
    _bind_common_vocab_prefixes(graph)
    return graph, source_iris

retrieve(query, top_k=None, expand_sparql=True, subgraph_depth=1, max_total_triples=300, estimated_triples_per_query=24)

Retrieve top-k hits for one query and optional induced subgraph; returns source ontology IRIs.

Source code in ontocast/tool/vector_store/patch_retriever.py
def retrieve(
    self,
    query: str,
    top_k: int | None = None,
    expand_sparql: bool = True,
    subgraph_depth: int = 1,
    max_total_triples: int = 300,
    estimated_triples_per_query: int = 24,
) -> tuple[RDFGraph, list[str]]:
    """Retrieve top-k hits for one query and optional induced subgraph; returns source ontology IRIs.

    Blocking wrapper that runs :meth:`aretrieve` in a fresh event loop.
    All arguments are forwarded unchanged.

    Raises:
        RuntimeError: If an event loop is already running in this thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running here, so starting one below is safe.
        pass
    else:
        raise RuntimeError(
            "retrieve() cannot be called from async code; use await aretrieve()"
        )
    # Run outside the ``except`` block so errors raised by the coroutine are
    # not implicitly chained to the "no running event loop" probe error.
    return asyncio.run(
        self.aretrieve(
            query=query,
            top_k=top_k,
            expand_sparql=expand_sparql,
            subgraph_depth=subgraph_depth,
            max_total_triples=max_total_triples,
            estimated_triples_per_query=estimated_triples_per_query,
        )
    )

retrieve_ensemble(queries, top_k=None, expand_sparql=True, subgraph_depth=1, max_total_triples=300, estimated_triples_per_query=24)

Sync: one induced graph and source IRIs for the union of vector hits over queries.

Source code in ontocast/tool/vector_store/patch_retriever.py
def retrieve_ensemble(
    self,
    queries: list[str],
    top_k: int | None = None,
    expand_sparql: bool = True,
    subgraph_depth: int = 1,
    max_total_triples: int = 300,
    estimated_triples_per_query: int = 24,
) -> tuple[RDFGraph, list[str]]:
    """Return one induced graph and source IRIs for the union of vector hits over ``queries``.

    Blocking wrapper that runs :meth:`aretrieve_ensemble` in a fresh event
    loop. All arguments are forwarded unchanged.

    Raises:
        RuntimeError: If an event loop is already running in this thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running here, so starting one below is safe.
        pass
    else:
        raise RuntimeError(
            "retrieve_ensemble() is not allowed inside async code; use aretrieve_ensemble()"
        )
    # Run outside the ``except`` block so errors raised by the coroutine are
    # not implicitly chained to the "no running event loop" probe error.
    return asyncio.run(
        self.aretrieve_ensemble(
            queries=queries,
            top_k=top_k,
            expand_sparql=expand_sparql,
            subgraph_depth=subgraph_depth,
            max_total_triples=max_total_triples,
            estimated_triples_per_query=estimated_triples_per_query,
        )
    )