Skip to content

graflo.architecture.pipeline.runtime.assemble

Assembly phase for turning extracted observations into graph edges.

assemble_edges(*, ctx, vertex_config, edge_config, infer_edges, infer_edge_only=None, infer_edge_except=None)

Assemble all edge documents after extraction finishes.

Source code in graflo/architecture/pipeline/runtime/assemble.py
def assemble_edges(
    *,
    ctx: AssemblyContext,
    vertex_config: VertexConfig,
    edge_config: EdgeConfig,
    infer_edges: bool,
    infer_edge_only: set[EdgeId] | None = None,
    infer_edge_except: set[EdgeId] | None = None,
) -> None:
    """Emit edge documents for explicit requests, then optionally infer more.

    The work happens in two passes:

    1. Explicit pass: every ``(edge, location)`` pair taken from
       ``ctx.edge_intents`` is handed to ``_emit_edge_documents``. When there
       are no intents, the pairs in ``ctx.edge_requests`` are used instead.
       Afterwards ``ctx.edge_requests`` and ``ctx.extraction.edge_intents``
       are both reset to empty lists.
    2. Inference pass (only when ``infer_edges`` is true): each entry of
       ``edge_config`` is considered, skipping those whose
       ``(source, target)`` pair has already produced documents, whose source
       or target has no truthy accumulated value in ``ctx.acc_vertex``, or
       which ``_is_inference_allowed`` rejects. Survivors are emitted with
       ``lindex=None``.

    Args:
        ctx: Assembly context holding pending requests and accumulated
            vertices; mutated in place.
        vertex_config: Forwarded unchanged to ``_emit_edge_documents``.
        edge_config: Mapping-like object whose ``items()`` yields
            ``(edge_id, edge)`` pairs; ``edge_id`` is unpacked as a 3-tuple
            whose first two elements are the source and target.
        infer_edges: Whether to run the inference pass at all.
        infer_edge_only: Forwarded to ``_is_inference_allowed``; ``None`` is
            treated as an empty set.
        infer_edge_except: Forwarded to ``_is_inference_allowed``; ``None`` is
            treated as an empty set.
    """
    only_filter = set() if infer_edge_only is None else infer_edge_only
    except_filter = set() if infer_edge_except is None else infer_edge_except

    # Intents take precedence; the plain request list is only a fallback.
    pending = [(intent.edge, intent.location) for intent in ctx.edge_intents]
    if not pending:
        pending = list(ctx.edge_requests)

    # (source, target) pairs that have already yielded documents.
    covered: set[tuple[str, str]] = set()

    for requested_edge, location in pending:
        produced = _emit_edge_documents(
            ctx=ctx,
            vertex_config=vertex_config,
            edge=requested_edge,
            lindex=location,
        )
        if produced:
            covered.add((requested_edge.source, requested_edge.target))

    # Both queues are drained regardless of whether inference runs.
    ctx.edge_requests = []
    ctx.extraction.edge_intents = []

    if not infer_edges:
        return

    # Vertex keys with at least one truthy value in their accumulator mapping.
    populated = {
        vertex_name
        for vertex_name, docs in ctx.acc_vertex.items()
        if any(docs.values())
    }

    for edge_id, candidate in edge_config.items():
        source, target, _ = edge_id
        endpoints = (source, target)

        if endpoints in covered:
            continue
        if source not in populated or target not in populated:
            continue

        allowed = _is_inference_allowed(
            edge_id,
            infer_edge_only=only_filter,
            infer_edge_except=except_filter,
        )
        if not allowed:
            continue

        produced = _emit_edge_documents(
            ctx=ctx,
            vertex_config=vertex_config,
            edge=candidate,
            lindex=None,
        )
        if produced:
            # Later entries sharing these endpoints are skipped.
            covered.add(endpoints)