Skip to content

graflo.architecture.pipeline.runtime.actor.wrapper

Actor wrapper for managing actor instances and assembly.

ActorWrapper

Wrapper class for managing actor instances.

Source code in graflo/architecture/pipeline/runtime/actor/wrapper.py
class ActorWrapper:
    """Wrapper class for managing actor instances."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create a wrapper from raw constructor arguments.

        The arguments are normalized by ``parse_root_config`` and the resulting
        config is handed to ``ActorWrapper.from_config``. This instance then
        adopts the actor and the default init context of that temporary wrapper.
        """
        parsed = parse_root_config(*args, **kwargs)
        template = ActorWrapper.from_config(parsed)
        self.actor, self.init_ctx = template.actor, template.init_ctx

    @property
    def vertex_config(self) -> VertexConfig:
        """Vertex configuration taken from the current init context."""
        ctx = self.init_ctx
        return ctx.vertex_config

    @property
    def edge_config(self) -> EdgeConfig:
        """Edge configuration taken from the current init context."""
        ctx = self.init_ctx
        return ctx.edge_config

    @property
    def infer_edges(self) -> bool:
        """``infer_edges`` flag from the current init context (passed to ``assemble_edges``)."""
        ctx = self.init_ctx
        return ctx.infer_edges

    @property
    def infer_edge_only(self) -> set[EdgeId]:
        """``infer_edge_only`` edge-id set from the current init context."""
        ctx = self.init_ctx
        return ctx.infer_edge_only

    @property
    def infer_edge_except(self) -> set[EdgeId]:
        """``infer_edge_except`` edge-id set from the current init context."""
        ctx = self.init_ctx
        return ctx.infer_edge_except

    def init_transforms(self, init_ctx: ActorInitContext) -> None:
        """Adopt *init_ctx* and forward it to the wrapped actor's ``init_transforms``."""
        wrapped = self.actor
        self.init_ctx = init_ctx
        wrapped.init_transforms(self.init_ctx)

    def finish_init(self, init_ctx: ActorInitContext) -> None:
        """Adopt *init_ctx*, then run the actor's transform init and finish init.

        The actor's ``init_transforms`` is always called before its
        ``finish_init``, both with the same context object.
        """
        self.init_ctx = init_ctx
        wrapped = self.actor
        wrapped.init_transforms(init_ctx)
        wrapped.finish_init(init_ctx)

    def count(self) -> int:
        """Return whatever the wrapped actor's ``count()`` reports."""
        wrapped = self.actor
        return wrapped.count()

    @classmethod
    def from_config(cls, config: ActorConfig) -> ActorWrapper:
        """Build a wrapper around the actor matching *config*'s type.

        The wrapper is created with ``cls.__new__`` (bypassing ``__init__``) and
        gets an empty default init context: no vertices, a default
        ``EdgeConfig``, no transforms, and edge inference enabled with empty
        only/except sets.

        Raises:
            ValueError: if *config* is not one of the supported actor configs.
        """
        # Checked in this exact order; the first isinstance match wins.
        dispatch = (
            (VertexActorConfig, VertexActor),
            (TransformActorConfig, TransformActor),
            (EdgeActorConfig, EdgeActor),
            (DescendActorConfig, DescendActor),
            (VertexRouterActorConfig, VertexRouterActor),
            (EdgeRouterActorConfig, EdgeRouterActor),
        )
        for config_type, actor_type in dispatch:
            if isinstance(config, config_type):
                actor = actor_type.from_config(config)
                break
        else:
            raise ValueError(
                "Expected VertexActorConfig, TransformActorConfig, EdgeActorConfig, "
                "DescendActorConfig, VertexRouterActorConfig, or EdgeRouterActorConfig, "
                f"got {type(config)}"
            )

        default_ctx = ActorInitContext(
            vertex_config=VertexConfig(vertices=[]),
            edge_config=EdgeConfig(),
            transforms={},
            infer_edges=True,
            infer_edge_only=set(),
            infer_edge_except=set(),
        )
        wrapper = cls.__new__(cls)
        wrapper.actor = actor
        wrapper.init_ctx = default_ctx
        return wrapper

    @classmethod
    def _from_step(cls, step: dict[str, Any]) -> ActorWrapper:
        """Normalize and validate a raw pipeline step, then build a wrapper from it."""
        normalized = normalize_actor_step(step)
        validated = validate_actor_step(normalized)
        return cls.from_config(validated)

    def __call__(
        self,
        ctx: ExtractionContext,
        lindex: LocationIndex | None = None,
        *nargs: Any,
        **kwargs: Any,
    ) -> ExtractionContext:
        """Run the wrapped actor on *ctx* at location *lindex*.

        Args:
            ctx: Extraction context forwarded to the actor.
            lindex: Location index for this invocation. When omitted, a fresh
                ``LocationIndex()`` is created for each call. The previous
                ``= LocationIndex()`` default was evaluated once at definition
                time and shared by every call.
            *nargs: Extra positional arguments forwarded verbatim to the actor.
            **kwargs: Extra keyword arguments forwarded verbatim to the actor.

        Returns:
            The context returned by the wrapped actor.
        """
        if lindex is None:
            lindex = LocationIndex()
        return self.actor(ctx, lindex, *nargs, **kwargs)

    def assemble(
        self, ctx: ExtractionContext | AssemblyContext | ActionContext
    ) -> defaultdict[GraphEntity, list]:
        """Assemble edges and deduplicated vertices into the global accumulator.

        Steps:
        1. Use *ctx* directly if it is an ``AssemblyContext``; otherwise build
           one with ``AssemblyContext.from_extraction``.
        2. Run ``assemble_edges`` with this wrapper's vertex/edge configs and
           edge-inference settings.
        3. For every vertex name and every per-location group in
           ``acc_vertex``, unwrap ``.vertex`` from each entry and merge the
           documents with ``merge_doc_basis`` on the vertex's identity fields.
           Deduplicate the result with ``pick_unique_dict`` and append it to
           ``acc_global[vertex_name]``.
        4. Apply ``add_blank_collections``.

        Returns:
            The assembled ``acc_global`` mapping. When *ctx* is an
            ``ActionContext``, its ``acc_global`` is replaced by that mapping
            and the context's attribute is returned.
        """
        if isinstance(ctx, AssemblyContext):
            assembly_ctx = ctx
        else:
            assembly_ctx = AssemblyContext.from_extraction(ctx)
        assemble_edges(
            ctx=assembly_ctx,
            vertex_config=self.vertex_config,
            edge_config=self.edge_config,
            infer_edges=self.infer_edges,
            infer_edge_only=self.infer_edge_only,
            infer_edge_except=self.infer_edge_except,
        )

        vertex_config = self.vertex_config
        for vertex_name, by_location in assembly_ctx.acc_vertex.items():
            if not by_location:
                # Keep the original behaviour of never resolving identity
                # fields for a vertex that has no collected entries.
                continue
            # Identity fields depend only on the vertex name, so resolve them
            # once per vertex instead of once per location group.
            identity = tuple(vertex_config.identity_fields(vertex_name))
            # The location key itself is not needed here, only the groups.
            for wrapped_vertices in by_location.values():
                docs = [item.vertex for item in wrapped_vertices]
                merged = merge_doc_basis(docs, identity)
                assembly_ctx.acc_global[vertex_name] += pick_unique_dict(merged)

        assembly_ctx = add_blank_collections(assembly_ctx, self.vertex_config)

        if isinstance(ctx, ActionContext):
            ctx.acc_global = assembly_ctx.acc_global
            return ctx.acc_global
        return assembly_ctx.acc_global

    @classmethod
    def from_dict(cls, data: dict | list) -> ActorWrapper:
        """Construct a wrapper from serialized data.

        A list is splatted as positional constructor arguments; anything else
        is splatted as keyword arguments.
        """
        if not isinstance(data, list):
            return cls(**data)
        return cls(*data)

    def assemble_tree(self, fig_path: Path | None = None):
        """Build, or render to a file, a graph of the actor tree.

        Nodes and edges come from ``self.fetch_actors(0, [])``. Each edge
        record is ``(head_a, head_b, props_a, props_b)``. Nodes are coloured
        by their ``props["class"]`` actor type.

        Args:
            fig_path: When given, the graph is drawn to this path as a PDF
                using graphviz ``dot`` (via ``networkx.nx_agraph``).

        Returns:
            A ``networkx.MultiDiGraph`` when ``fig_path`` is None, otherwise
            ``None``. Also returns ``None`` (after logging an error) when
            networkx cannot be imported.
        """
        import logging

        logger = logging.getLogger(__name__)
        _, _, _, actor_edges = self.fetch_actors(0, [])
        logger.info("%s", len(actor_edges))
        try:
            import networkx as nx
        except ImportError as e:
            logger.error("not able to import networkx %s", e)
            return None
        from graflo.plot.plotter import fillcolor_palette

        map_class2color = {
            DescendActor: fillcolor_palette["green"],
            VertexActor: "orange",
            VertexRouterActor: fillcolor_palette["peach"],
            EdgeRouterActor: fillcolor_palette["red"],
            EdgeActor: fillcolor_palette["violet"],
            TransformActor: fillcolor_palette["blue"],
        }

        nodes = {}
        for ha, hb, pa, pb in actor_edges:
            nodes[ha] = pa
            nodes[hb] = pb

        for props in nodes.values():
            # Lookup is by exact class; an actor class missing from the map
            # (e.g. a subclass) gets a neutral fill instead of raising KeyError.
            props["fillcolor"] = map_class2color.get(props["class"], "white")
            props["style"] = "filled"
            props["color"] = "brown"

        g = nx.MultiDiGraph()
        g.add_edges_from([(ha, hb) for ha, hb, _, _ in actor_edges])
        # Attaches the styling attributes to the nodes created by the edges.
        g.add_nodes_from(nodes.items())

        if fig_path is not None:
            ag = nx.nx_agraph.to_agraph(g)
            ag.draw(fig_path, "pdf", prog="dot")
            return None
        return g

    def fetch_actors(self, level: int, edges: list) -> tuple[int, type, str, list]:
        """Delegate to the wrapped actor's ``fetch_actors(level, edges)`` unchanged."""
        wrapped = self.actor
        return wrapped.fetch_actors(level, edges)

    def collect_actors(self) -> list[Actor]:
        """Return this wrapper's actor followed by all nested actors (pre-order).

        Only a ``DescendActor`` has children. Each child wrapper's
        ``collect_actors`` result is appended in ``descendants`` order.
        """
        collected: list[Actor] = [self.actor]
        if not isinstance(self.actor, DescendActor):
            return collected
        for child in self.actor.descendants:
            collected += child.collect_actors()
        return collected

    def find_descendants(
        self,
        predicate: Callable[[ActorWrapper], bool] | None = None,
        *,
        actor_type: type[Actor] | None = None,
        **attr_in: Any,
    ) -> list[ActorWrapper]:
        """Collect nested wrappers that satisfy a predicate (pre-order, self excluded).

        Args:
            predicate: Custom test applied to each descendant wrapper. When
                given, ``actor_type`` and ``attr_in`` are ignored.
            actor_type: Without a predicate, keep only wrappers whose actor is
                an instance of this type.
            **attr_in: Without a predicate, each ``attr=allowed`` pair keeps
                only wrappers whose ``getattr(actor, attr, None)`` is in
                ``allowed``. Pairs whose ``allowed`` is None are ignored.

        Returns:
            Matching descendant wrappers. A match is listed before its own
            matching descendants. Returns an empty list when this wrapper's
            actor is not a ``DescendActor``.
        """
        if predicate is None:

            def _matches(node: ActorWrapper) -> bool:
                if actor_type is not None and not isinstance(node.actor, actor_type):
                    return False
                return all(
                    getattr(node.actor, attr, None) in allowed
                    for attr, allowed in attr_in.items()
                    if allowed is not None
                )

            predicate = _matches

        if not isinstance(self.actor, DescendActor):
            return []
        found: list[ActorWrapper] = []
        for child in self.actor.descendants:
            if predicate(child):
                found.append(child)
            found += child.find_descendants(predicate=predicate)
        return found

    def remove_descendants_if(self, predicate: Callable[[ActorWrapper], bool]) -> None:
        """Recursively prune nested wrappers for which *predicate* is true.

        Does nothing unless this wrapper's actor is a ``DescendActor``. Each
        child is pruned first (bottom-up), and only then is this level
        filtered. The filter drops children that match *predicate*, and also
        drops ``DescendActor`` children whose ``count()`` is 0 after their own
        pruning (NOTE(review): presumably this means "no remaining content" —
        confirm ``DescendActor.count`` semantics).

        Args:
            predicate: Test applied to each descendant wrapper; True removes it.
        """
        if isinstance(self.actor, DescendActor):
            # Iterate over a snapshot so nested pruning cannot disturb this loop.
            for d in list(self.actor.descendants):
                d.remove_descendants_if(predicate=predicate)
            # Slice assignment rewrites the actor's ``_descendants`` list in place
            # rather than rebinding it. NOTE(review): assumes ``descendants``
            # exposes this same ``_descendants`` list — verify in DescendActor.
            self.actor._descendants[:] = [
                d
                for d in self.actor.descendants
                if not predicate(d)
                and not (isinstance(d.actor, DescendActor) and d.count() == 0)
            ]