Skip to content

graflo.architecture.pipeline.runtime.actor.edge_router

Edge router actor for routing documents to dynamically created edges.

EdgeRouterActor

Bases: Actor

Routes documents to dynamically created edges based on type fields.

Source code in graflo/architecture/pipeline/runtime/actor/edge_router.py
class EdgeRouterActor(Actor):
    """Routes documents to dynamically created edges based on type fields.

    For every document the actor determines a source and a target vertex
    type -- either fixed through ``source`` / ``target`` or read from the
    document through ``source_type_field`` / ``target_type_field`` -- plus
    an optional relation. It then projects the two vertex payloads out of
    the document and records the vertices and an edge request on the
    extraction context. One ``Edge`` is created per distinct
    ``(source, target, relation)`` triple and cached for reuse.

    Documents that cannot be routed (no type source configured, missing type
    field, type not in the vertex set, unhashable type or relation value,
    empty projection) are skipped with a debug log and the context is
    returned unchanged.
    """

    def __init__(self, config: EdgeRouterActorConfig):
        """Copy routing settings from ``config`` and set up empty state.

        Args:
            config: Routing configuration for this actor.
        """
        self.source_type_field = config.source_type_field
        self.target_type_field = config.target_type_field
        self.source = config.source
        self.target = config.target
        self.source_fields = config.source_fields
        self.target_fields = config.target_fields
        self.relation_field = config.relation_field
        self.relation = config.relation
        # The shared ``type_map`` is overlaid by the side-specific map, so a
        # side-specific entry wins when both define the same key.
        self._source_type_map: dict[str, str] = {
            **(config.type_map or {}),
            **(config.source_type_map or {}),
        }
        self._target_type_map: dict[str, str] = {
            **(config.type_map or {}),
            **(config.target_type_map or {}),
        }
        self._relation_map: dict[str, str] = config.relation_map or {}
        # Edges created so far, keyed by (source, target, relation).
        self._edge_cache: dict[tuple[str, str, str | None], Edge] = {}
        # Empty placeholders; replaced in ``finish_init``.
        self._init_ctx: ActorInitContext | None = None
        self.vertex_config: VertexConfig = VertexConfig(vertices=[])
        self.edge_config: EdgeConfig = EdgeConfig()

    @classmethod
    def from_config(cls, config: EdgeRouterActorConfig) -> EdgeRouterActor:
        """Build an actor from ``config`` (equivalent to ``cls(config)``)."""
        return cls(config)

    def finish_init(self, init_ctx: ActorInitContext) -> None:
        """Adopt the vertex/edge configs from ``init_ctx`` and reset the cache.

        Args:
            init_ctx: Initialisation context providing ``vertex_config`` and
                ``edge_config``.
        """
        self._init_ctx = init_ctx
        self.vertex_config = init_ctx.vertex_config
        self.edge_config = init_ctx.edge_config
        # Cached edges were finalised against the previous vertex config.
        self._edge_cache.clear()

    @staticmethod
    def _is_hashable(value: Any) -> bool:
        """Return whether ``value`` can be used as a dict key."""
        try:
            hash(value)
        except TypeError:
            return False
        return True

    def _resolve_type(self, raw: str, type_map: dict[str, str]) -> str | None:
        """Translate a raw type value into a known vertex name.

        Args:
            raw: Type value as configured or as read from the document.
            type_map: Alias map; values absent from it are used unchanged.

        Returns:
            The resolved vertex name, or ``None`` when ``raw`` is unhashable
            or the resolved name is not in ``vertex_config.vertex_set``.
        """
        # Document values are not type-checked at runtime; a list or dict
        # here would raise TypeError in the lookups below, so skip it like
        # any other unroutable value.
        if not self._is_hashable(raw):
            logger.debug(
                "EdgeRouterActor: unhashable type value %r, skipping",
                raw,
            )
            return None
        resolved = type_map.get(raw, raw)
        if resolved not in self.vertex_config.vertex_set:
            logger.debug(
                "EdgeRouterActor: resolved type '%s' not in vertex_set, skipping",
                resolved,
            )
            return None
        return resolved

    def _resolve_relation(self, raw: str | None) -> str | None:
        """Map a raw relation through ``relation_map``.

        Args:
            raw: Raw relation value; must be hashable.

        Returns:
            ``None`` when ``raw`` is ``None``; otherwise the mapped relation,
            or ``raw`` itself when the map has no entry for it.
        """
        if raw is None:
            return None
        return self._relation_map.get(raw, raw)

    def _resolve_side_type(
        self,
        doc: dict[str, Any],
        *,
        explicit_type: str | None,
        type_field: str | None,
        type_map: dict[str, str],
        side_name: str,
    ) -> str | None:
        """Determine the vertex type for one side of the edge.

        An explicit type takes precedence; otherwise the type is read from
        ``doc[type_field]``. In both cases the value goes through
        ``_resolve_type``.

        Args:
            doc: Document being routed.
            explicit_type: Fixed type for this side, if configured.
            type_field: Document key holding the type, if configured.
            type_map: Alias map for this side.
            side_name: Label used in log messages ("source" / "target").

        Returns:
            The resolved vertex name, or ``None`` when it cannot be
            determined.
        """
        if explicit_type is not None:
            return self._resolve_type(explicit_type, type_map)

        if type_field is None:
            logger.debug(
                "EdgeRouterActor: no %s type source configured, skipping",
                side_name,
            )
            return None

        raw_type = doc.get(type_field)
        if raw_type is None:
            logger.debug(
                "EdgeRouterActor: missing %s type field '%s' in doc, skipping",
                side_name,
                type_field,
            )
            return None

        return self._resolve_type(raw_type, type_map)

    def _get_or_create_edge(
        self,
        source_name: str,
        target_name: str,
        relation: str | None,
    ) -> Edge:
        """Return the cached edge for the triple, creating it on first use.

        A new edge is finalised with ``Edge.finish_init``, passed to
        ``edge_config.update_edges`` and then stored in the cache.

        Args:
            source_name: Resolved source vertex name.
            target_name: Resolved target vertex name.
            relation: Resolved relation, or ``None``.

        Returns:
            The edge for ``(source_name, target_name, relation)``.
        """
        key = (source_name, target_name, relation)
        cached = self._edge_cache.get(key)
        if cached is not None:
            return cached
        edge = Edge(source=source_name, target=target_name, relation=relation)
        edge.finish_init(vertex_config=self.vertex_config)
        self.edge_config.update_edges(edge, vertex_config=self.vertex_config)
        self._edge_cache[key] = edge
        logger.debug(
            "EdgeRouterActor: registered dynamic edge (%s, %s, %s)",
            source_name,
            target_name,
            relation,
        )
        return edge

    def _project_vertex_doc(
        self,
        doc: dict[str, Any],
        fields: dict[str, str] | None,
        vertex_name: str,
    ) -> dict[str, Any]:
        """Extract the payload for one vertex from ``doc``.

        Args:
            doc: Document being routed.
            fields: Mapping of vertex field name -> document field name. When
                ``None``, the names from
                ``vertex_config.identity_fields(vertex_name)`` are copied
                under the same name instead.
            vertex_name: Resolved vertex name, used only for the fallback.

        Returns:
            The projected payload; keys missing from ``doc`` are omitted, so
            the result may be empty.
        """
        if fields is not None:
            return {vf: doc[df] for vf, df in fields.items() if df in doc}
        identity = self.vertex_config.identity_fields(vertex_name)
        return {f: doc[f] for f in identity if f in doc}

    def __call__(
        self,
        ctx: ExtractionContext,
        lindex: LocationIndex,
        *nargs: Any,
        **kwargs: Any,
    ) -> ExtractionContext:
        """Route one document into two vertices and an edge request.

        Args:
            ctx: Extraction context to update.
            lindex: Location of the document; the edge is recorded here and
                the vertices at locations derived from it.
            *nargs: Ignored.
            **kwargs: Only ``doc`` is read (defaults to an empty dict).

        Returns:
            ``ctx``, untouched when the document is skipped.
        """
        doc: dict[str, Any] = kwargs.get("doc", {})

        source_name = self._resolve_side_type(
            doc,
            explicit_type=self.source,
            type_field=self.source_type_field,
            type_map=self._source_type_map,
            side_name="source",
        )
        target_name = self._resolve_side_type(
            doc,
            explicit_type=self.target,
            type_field=self.target_type_field,
            type_map=self._target_type_map,
            side_name="target",
        )
        if source_name is None or target_name is None:
            return ctx

        # A configured relation_field takes precedence over the static
        # relation, even when the document has no value for it.
        raw_relation = (
            doc.get(self.relation_field) if self.relation_field else self.relation
        )
        # The relation is part of the edge-cache key; reject unhashable
        # values before anything is written to ctx.
        if not self._is_hashable(raw_relation):
            logger.debug(
                "EdgeRouterActor: unhashable relation value %r, skipping",
                raw_relation,
            )
            return ctx
        relation = self._resolve_relation(raw_relation)

        source_doc = self._project_vertex_doc(doc, self.source_fields, source_name)
        target_doc = self._project_vertex_doc(doc, self.target_fields, target_name)

        if not source_doc or not target_doc:
            logger.debug(
                "EdgeRouterActor: could not project identity docs for "
                "(%s, %s), skipping",
                source_name,
                target_name,
            )
            return ctx

        source_lindex = lindex.extend(("src", 0))
        target_lindex = lindex.extend(("tgt", 0))
        ctx.acc_vertex[source_name][source_lindex].append(
            VertexRep(vertex=source_doc, ctx={})
        )
        ctx.acc_vertex[target_name][target_lindex].append(
            VertexRep(vertex=target_doc, ctx={})
        )

        edge = self._get_or_create_edge(source_name, target_name, relation)
        # NOTE(review): the edge is recorded both in edge_requests and via
        # record_edge_intent; presumably each has its own consumer -- confirm.
        ctx.edge_requests.append((edge, lindex))
        ctx.record_edge_intent(edge=edge, location=lindex)
        return ctx

    def references_vertices(self) -> set[str]:
        """Return the vertex names used by the edges created so far.

        The result is derived from the edge cache, so it is empty until at
        least one document has been routed after ``finish_init``.
        """
        names: set[str] = set()
        for source_name, target_name, _ in self._edge_cache:
            names.add(source_name)
            names.add(target_name)
        return names

    def fetch_important_items(self) -> dict[str, Any]:
        """Return a summary of the routing settings and the cached edges.

        ``relation_map`` is included only when it is non-empty.
        """
        items: dict[str, Any] = {
            "source": self.source or "",
            "target": self.target or "",
            "source_type_field": self.source_type_field,
            "target_type_field": self.target_type_field,
            "relation_field": self.relation_field or "",
            "cached_edges": sorted(str(k) for k in self._edge_cache),
        }
        if self._relation_map:
            # Copy so callers cannot mutate the actor's map via the summary.
            items["relation_map"] = dict(self._relation_map)
        return items