Skip to content

graflo.hq.document_caster

Stateless document-to-graph casting (no I/O).

DocumentCaster

Cast source documents to a `GraphContainer` via ingestion resources.

Source code in graflo/hq/document_caster.py
class DocumentCaster:
    """Cast source documents to :class:`GraphContainer` via ingestion resources.

    The caster only stores the :class:`IngestionModel` it was built with; the
    resource runtime is looked up again for every batch.
    """

    def __init__(self, ingestion_model: IngestionModel) -> None:
        """Remember the ingestion model used to resolve resources."""
        self.ingestion_model = ingestion_model

    async def cast_batch(
        self,
        data: Iterable[Any],
        resource_name: str | None,
        *,
        params: IngestionParams,
        allowed_vertex_names: set[str] | None = None,
    ) -> CastBatchResult:
        """Cast one batch of documents and return the filtered graph.

        Args:
            data: Source documents; materialised into a list before casting.
            resource_name: Forwarded unchanged to
                ``ingestion_model.fetch_resource``.
            params: Supplies ``on_doc_error``, ``n_cores``, the document
                preview settings used for failures, and
                ``drop_empty_identity_docs``.
            allowed_vertex_names: Optional restriction that is intersected
                with the vertex names of the resolved resource.

        Returns:
            A :class:`CastBatchResult` holding the graph built from the
            successfully cast documents and the collected failures.
        """
        runtime = self.ingestion_model.fetch_resource(resource_name)
        name = runtime.name
        retained_vertices = cast_vertex_filter(
            runtime.collect_vertex_names(),
            allowed_vertex_names=allowed_vertex_names,
        )

        documents = list(data)
        outcomes, failures = await self._gather_cast_results(
            runtime,
            documents,
            on_doc_error=params.on_doc_error,
            resolved_name=name,
            params=params,
        )

        # Only successful cast results contribute entities to the graph.
        entity_batches = []
        for outcome in outcomes:
            if isinstance(outcome, ResourceCastResult):
                entity_batches.append(outcome.entities)
        graph = GraphContainer.from_docs_list(entity_batches)

        filter_graph_container_by_vertices_inplace(
            graph, allowed_vertex_names=retained_vertices
        )
        if params.drop_empty_identity_docs:
            filter_graph_container_drop_empty_identity_inplace(
                graph,
                vertex_config=runtime.vertex_config,
            )
        return CastBatchResult(graph=graph, failures=failures)

    async def _gather_cast_results(
        self,
        runtime: ResourceRuntime,
        doc_list: list[Any],
        *,
        on_doc_error: Literal["fail", "skip"],
        resolved_name: str,
        params: IngestionParams,
    ) -> tuple[list[ResourceCastResult | BaseException], list[DocCastFailure]]:
        """Cast every document concurrently and split results from failures.

        Each document is cast by ``runtime.cast_document`` in a worker thread
        (``asyncio.to_thread``); a semaphore sized by ``params.n_cores`` limits
        how many casts are in flight at once.

        With ``on_doc_error == "fail"`` the first exception raised by a cast
        propagates out of ``asyncio.gather``. For any other value, exceptions
        are collected per document and reported as failures, except
        ``CancelledError``, ``KeyboardInterrupt`` and ``SystemExit``, which
        are re-raised.

        Returns:
            ``(cast_results, failures)``: the results of documents that cast
            without raising, and the failures gathered both from raised
            exceptions and from each result's ``transform_failures``.
        """
        limiter = asyncio.Semaphore(params.n_cores)

        async def cast_one(document: dict[str, Any]) -> ResourceCastResult:
            async with limiter:
                return await asyncio.to_thread(runtime.cast_document, document)

        pending = [cast_one(_coerce_doc(entry)) for entry in doc_list]
        # Exceptions are returned in place of results unless the batch must
        # fail on the first error.
        outcomes = await asyncio.gather(
            *pending, return_exceptions=on_doc_error != "fail"
        )

        never_swallowed = (asyncio.CancelledError, KeyboardInterrupt, SystemExit)
        successes: list[ResourceCastResult | BaseException] = []
        collected: list[DocCastFailure] = []
        for index, (entry, outcome) in enumerate(zip(doc_list, outcomes)):
            document = _coerce_doc(entry)
            if isinstance(outcome, never_swallowed):
                raise outcome

            # Arguments shared by both failure converters.
            report_args = {
                "resource_name": resolved_name,
                "doc_index": index,
                "doc": document,
                "doc_keys": params.doc_error_preview_keys,
                "doc_preview_max_bytes": params.doc_error_preview_max_bytes,
            }
            if not isinstance(outcome, BaseException):
                # A successful cast may still carry transform-level failures.
                collected.extend(
                    _transform_failures_to_doc_cast_failures(
                        transform_failures=outcome.transform_failures,
                        **report_args,
                    )
                )
                successes.append(outcome)
                continue
            collected.append(_doc_failure_from_exception(exc=outcome, **report_args))
        return successes, collected

cast_vertex_filter(resource_vertex_names, *, allowed_vertex_names)

Vertex names to retain after casting for a single resource.

Source code in graflo/hq/document_caster.py
def cast_vertex_filter(
    resource_vertex_names: set[str],
    *,
    allowed_vertex_names: set[str] | None,
) -> set[str]:
    """Vertex names to retain after casting for a single resource."""
    if allowed_vertex_names is None:
        return resource_vertex_names
    return resource_vertex_names & allowed_vertex_names

filter_graph_container_by_vertices_inplace(gc, *, allowed_vertex_names)

Restrict persistence to a subset of vertex types (in-place).

Source code in graflo/hq/document_caster.py
def filter_graph_container_by_vertices_inplace(
    gc: GraphContainer, *, allowed_vertex_names: set[str] | None
) -> None:
    """Restrict the container to a subset of vertex types (in-place).

    Args:
        gc: Container whose ``vertices`` and ``edges`` mappings are replaced.
        allowed_vertex_names: Vertex names to keep; ``None`` leaves the
            container untouched.

    Vertex collections outside the allowed set are dropped. An edge
    collection is kept only when both of its endpoint names are allowed.
    """
    if allowed_vertex_names is not None:
        kept_vertices = {}
        for name, docs in gc.vertices.items():
            if name in allowed_vertex_names:
                kept_vertices[name] = docs
        gc.vertices = kept_vertices

        kept_edges = {}
        for (source, target, relation), docs in gc.edges.items():
            if source in allowed_vertex_names and target in allowed_vertex_names:
                kept_edges[(source, target, relation)] = docs
        gc.edges = kept_edges

filter_graph_container_drop_empty_identity_inplace(gc, *, vertex_config)

Remove vertex docs and edge tuples with no usable schema identity.

Source code in graflo/hq/document_caster.py
def filter_graph_container_drop_empty_identity_inplace(
    gc: GraphContainer, *, vertex_config: VertexConfig
) -> None:
    """Drop vertex docs and edge tuples that have no usable schema identity.

    Args:
        gc: Container whose ``vertices`` and ``edges`` are pruned in place.
        vertex_config: Supplies ``blank_vertices``, ``vertex_set`` and the
            identity fields of each vertex type.

    Vertex types listed as blank or absent from ``vertex_set`` are left
    untouched, as are edges that have such a type at either end. Edge
    entries are checked on their first element (source document) and second
    element (target document). An edge collection that ends up empty is
    removed from ``gc.edges``; a vertex collection that ends up empty is
    kept as an empty list.
    """
    blank_names = set(vertex_config.blank_vertices)
    known_names = vertex_config.vertex_set

    # Iterate over snapshots so entries can be reassigned or deleted.
    for name, vertex_docs in list(gc.vertices.items()):
        if name not in blank_names and name in known_names:
            fields = vertex_config.identity_fields(name)
            survivors = []
            for vertex_doc in vertex_docs:
                if not _vertex_doc_has_empty_identity(vertex_doc, fields):
                    survivors.append(vertex_doc)
            gc.vertices[name] = survivors

    for edge_key, pairs in list(gc.edges.items()):
        source, target, _relation = edge_key
        if not (source in known_names and target in known_names):
            continue
        if source in blank_names or target in blank_names:
            continue
        source_fields = vertex_config.identity_fields(source)
        target_fields = vertex_config.identity_fields(target)
        remaining = []
        for pair in pairs:
            if _vertex_doc_has_empty_identity(
                pair[0], source_fields
            ) or _vertex_doc_has_empty_identity(pair[1], target_fields):
                continue
            remaining.append(pair)
        if not remaining:
            del gc.edges[edge_key]
        else:
            gc.edges[edge_key] = remaining