Skip to content

graflo.architecture.pipeline.runtime.actor.transform

Transform actor for applying transformations to data.

TransformActor

Bases: Actor

Actor for applying transformations to data.

Source code in graflo/architecture/pipeline/runtime/actor/transform.py
class TransformActor(Actor):
    """Actor for applying transformations to data.

    The actor runs in one of two mutually exclusive modes:

    * **rename** - ``config.rename`` is set. The actor wraps
      ``Transform(rename=...)`` and reports the renamed source keys that were
      present in the observation as removed.
    * **call** - ``config.call`` is set. The transform is either built inline
      from the call fields, or, when ``call.use`` names a transform whose
      function is not resolved yet, completed in :meth:`finish_init` by
      merging the call overrides with the named ``ingestion_model.transforms``
      entry.
    """

    def __init__(self, config: TransformActorConfig):
        """Build the actor from its declarative config.

        Args:
            config: Actor configuration; must provide ``rename`` or ``call``.

        Raises:
            ValueError: If neither ``rename`` nor ``call`` is set.
        """
        self.transforms: dict[str, ProtoTransform] = {}
        self.call_use: str | None = None
        self._call_config: Any = None
        self._fail_fast = False
        self._tolerate_transform_errors = True
        # Keys a dict observation must contain before the transform runs;
        # (re)computed by _refresh_missing_key_guard during init.
        self._declared_input_keys: frozenset[str] = frozenset()
        self._rename_map: dict[str, str] | None = None

        if config.rename is not None:
            self.t = Transform(rename=config.rename)
            self._rename_map = dict(config.rename)
            return

        if config.call is None:
            raise ValueError(
                "TransformActorConfig requires call when rename is absent."
            )

        call = config.call
        self._call_config = call
        self.call_use = call.use

        # When call.use references ingestion_model.transforms, defer strict
        # transform validation until finish_init can hydrate module/foo.
        if call.use is not None and call.module is None and call.foo is None:
            self.t = Transform(name=call.use)
            return
        self.t = Transform(**self._inline_transform_kwargs(call))

    @staticmethod
    def _tuple_or_empty(values: Any) -> tuple[Any, ...]:
        """Return ``values`` as a tuple; empty or ``None`` becomes ``()``."""
        return tuple(values) if values else ()

    @staticmethod
    def _groups_or_empty(groups: Any) -> tuple[tuple[Any, ...], ...]:
        """Return nested ``groups`` as a tuple of tuples; empty/``None`` -> ``()``."""
        return tuple(tuple(group) for group in groups) if groups else ()

    def _inline_transform_kwargs(self, call: Any) -> dict[str, Any]:
        """Translate an inline ``call`` config into ``Transform`` kwargs.

        Args:
            call: The ``call`` section of the actor config.

        Returns:
            Keyword arguments for ``Transform``.
        """
        transform_kwargs: dict[str, Any] = {
            "name": call.use,
            "params": call.params,
            "module": call.module,
            "foo": call.foo,
            "input": self._tuple_or_empty(call.input),
            "output": self._tuple_or_empty(call.output),
            "input_groups": self._groups_or_empty(call.input_groups),
            "output_groups": self._groups_or_empty(call.output_groups),
            "dress": call.dress,
            "strategy": call.strategy or "single",
        }
        # An explicit target always wins. A fully inline call (no ``use``)
        # defaults to "values"; a named call leaves ``target`` unset so the
        # Transform's own default applies.
        if call.target is not None:
            transform_kwargs["target"] = call.target
        elif call.use is None:
            transform_kwargs["target"] = "values"
        if call.use is None and call.keys is not None:
            transform_kwargs["keys"] = KeySelectionConfig.model_validate(
                call.keys.model_dump()
            )
        return transform_kwargs

    def _refresh_missing_key_guard(self, init_ctx: ActorInitContext) -> None:
        """Cache error-handling flags and the input keys to enforce per call.

        Key-mode transforms (``target == "keys"``) and ``strategy == "all"``
        are exempt from the missing-key guard. Rename actors are guarded only
        in fail-fast mode. Otherwise the guard is the union of ``t.input`` and
        every key in ``t.input_groups``.

        Args:
            init_ctx: Initialization context supplying ``fail_fast`` and
                ``tolerate_transform_errors``.
        """
        self._fail_fast = init_ctx.fail_fast
        self._tolerate_transform_errors = init_ctx.tolerate_transform_errors
        if self.t.target == "keys" or self.t.strategy == "all":
            self._declared_input_keys = frozenset()
            return
        if self._rename_map is not None and not self._fail_fast:
            self._declared_input_keys = frozenset()
            return
        required: set[str] = set(self.t.input)
        for group in self.t.input_groups:
            required.update(group)
        self._declared_input_keys = frozenset(required)

    @staticmethod
    def _extract_observation(nargs: tuple[Any, ...], **kwargs: Any) -> Any:
        """Return the observation slice: dict row or scalar/list positional value.

        The document is taken from the ``doc`` keyword when present, otherwise
        from the first positional argument.

        Raises:
            ValueError: If no document was supplied or it is ``None``.
        """
        if "doc" in kwargs:
            observation: Any | None = kwargs["doc"]
        elif nargs:
            observation = nargs[0]
        else:
            observation = None
        if observation is None:
            # Use the class name itself; ``type(TransformActor)`` would name
            # the metaclass instead.
            raise ValueError(f"{TransformActor.__name__}: doc should be provided")
        return observation

    @staticmethod
    def _observation_keys(observation: Any) -> frozenset[str]:
        """Return the stringified keys of a dict observation, else empty."""
        if isinstance(observation, dict):
            return frozenset(str(k) for k in observation)
        return frozenset()

    def _missing_declared_keys(self, observation: Any) -> frozenset[str]:
        """Return declared input keys absent from a dict observation.

        Non-dict observations and actors without declared keys never report
        missing keys.
        """
        if not self._declared_input_keys or not isinstance(observation, dict):
            return frozenset()
        return self._declared_input_keys - self._observation_keys(observation)

    def _rename_removed_keys(self, observation: Any) -> frozenset[str]:
        """Return rename source keys that are present in a dict observation."""
        if self._rename_map is None or not isinstance(observation, dict):
            return frozenset()
        return frozenset(src for src in self._rename_map if src in observation)

    def fetch_important_items(self) -> dict[str, Any]:
        """Return base items for ``("transform",)`` plus the transform's I/O."""
        items = self._fetch_items_from_dict(("transform",))
        items.update({"t.input": self.t.input, "t.output": self.t.output})
        return items

    @classmethod
    def from_config(cls, config: TransformActorConfig) -> TransformActor:
        """Alternate constructor; equivalent to ``cls(config)``."""
        return cls(config)

    def init_transforms(self, init_ctx: ActorInitContext) -> None:
        """Store the named transform registry from ``init_ctx``."""
        self.transforms = init_ctx.transforms

    def _merge_call_with_proto(self, call: Any, pt: ProtoTransform) -> dict[str, Any]:
        """Merge an inline ``call`` override with the named proto transform.

        Call fields override the proto's when set. ``module``/``foo`` always
        come from ``pt`` and ``name`` from ``call.use``.

        Args:
            call: The ``call`` section of the actor config.
            pt: The ``ingestion_model.transforms`` entry named by ``call.use``.

        Returns:
            Keyword arguments for ``Transform``.

        Raises:
            ValueError: If the effective target is ``"keys"`` and the call also
                sets input/output (groups), ``dress``, or a non-``"single"``
                strategy.
        """
        next_params = call.params if call.params else pt.params
        next_dress = call.dress if call.dress is not None else pt.dress
        next_target = call.target if call.target is not None else pt.target

        if next_target == "keys":
            if call.input or call.output or call.input_groups or call.output_groups:
                raise ValueError(
                    "call.input, call.output, call.input_groups, and call.output_groups "
                    "cannot be used when the effective transform target is keys "
                    "(from call.target or the named ingestion_model.transforms entry)."
                )
            if call.dress is not None:
                raise ValueError("call.dress is not supported when target='keys'.")
            if call.strategy is not None and call.strategy != "single":
                raise ValueError(
                    "call.strategy is not allowed when target='keys'; "
                    "key mode uses implicit per-key execution."
                )
            # Key mode runs per key, so no explicit field wiring is kept.
            next_input: tuple[str, ...] = ()
            next_output: tuple[str, ...] = ()
            next_input_groups: tuple[tuple[str, ...], ...] = ()
            next_output_groups: tuple[tuple[str, ...], ...] = ()
        else:
            next_input_groups = (
                self._groups_or_empty(call.input_groups)
                if call.input_groups
                else pt.input_groups
            )
            next_output_groups = (
                self._groups_or_empty(call.output_groups)
                if call.output_groups
                else pt.output_groups
            )
            if next_input_groups:
                # Grouped inputs replace flat inputs entirely.
                next_input = ()
                # Explicit grouped override should not inherit potentially
                # conflicting proto output/output_groups for a different shape.
                if call.input_groups:
                    next_output_groups = self._groups_or_empty(call.output_groups)
                    next_output = self._tuple_or_empty(call.output)
                elif next_dress is not None:
                    next_output = (next_dress.key, next_dress.value)
                else:
                    next_output = tuple(call.output) if call.output else pt.output
            else:
                next_input = tuple(call.input) if call.input else pt.input
                if next_dress is not None:
                    next_output = (next_dress.key, next_dress.value)
                else:
                    next_output = tuple(call.output) if call.output else pt.output

        transform_kwargs: dict[str, Any] = {
            "dress": next_dress,
            "name": call.use,
            "module": pt.module,
            "foo": pt.foo,
            "params": next_params,
            "input": next_input,
            "output": next_output,
            "input_groups": next_input_groups,
            "output_groups": next_output_groups,
            # NOTE(review): strategy is taken from the call only and is not
            # inherited from pt - confirm this is intended.
            "strategy": call.strategy or "single",
            "target": next_target,
        }
        if call.keys is not None:
            transform_kwargs["keys"] = KeySelectionConfig.model_validate(
                call.keys.model_dump()
            )
        else:
            transform_kwargs["keys"] = pt.keys.model_copy(deep=True)
        return transform_kwargs

    def finish_init(self, init_ctx: ActorInitContext) -> None:
        """Resolve a deferred ``call.use`` reference and install the key guard.

        Args:
            init_ctx: Initialization context with the transform registry and
                the error-handling flags.

        Raises:
            ValueError: If ``call.use`` names an unknown transform and
                ``init_ctx.strict_references`` is set.
        """
        self.transforms = init_ctx.transforms
        # Only named calls whose function is still unresolved need merging.
        needs_merge = (
            self.call_use is not None
            and self.t._foo is None
            and self._call_config is not None
        )
        if needs_merge:
            pt = self.transforms.get(self.call_use)
            if pt is not None:
                self.t = Transform(
                    **self._merge_call_with_proto(self._call_config, pt)
                )
            elif init_ctx.strict_references:
                raise ValueError(
                    f"Transform '{self.call_use}' referenced by transform.call.use "
                    "was not found in ingestion_model.transforms."
                )
        self._refresh_missing_key_guard(init_ctx)

    def _format_transform_result(self, result: Any) -> TransformPayload:
        """Wrap a raw transform result in a ``TransformPayload``."""
        return TransformPayload.from_result(result)

    def _transform_label(self) -> str:
        """Return a human-readable label for failure reports.

        Preference order: ``call.use``, ``module.foo``, ``foo``, the
        transform's ``name``, then the transform's type name.
        """
        if self.call_use:
            return self.call_use
        if self.t.foo and self.t.module:
            return f"{self.t.module}.{self.t.foo}"
        if self.t.foo:
            return self.t.foo
        if self.t.name:
            return self.t.name
        return type(self.t).__name__

    def _format_traceback(self, exc: BaseException) -> str:
        """Render the full traceback of ``exc`` as a single string."""
        return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    def __call__(
        self, ctx: ExtractionContext, lindex: LocationIndex, *nargs: Any, **kwargs: Any
    ) -> ExtractionContext:
        """Apply the transform to one observation and buffer the result.

        The observation comes from ``doc=`` or the first positional argument.
        If it is a dict missing any declared input key, the transform is
        skipped, or ``TransformException`` is raised in fail-fast mode. If the
        transform raises and errors are tolerated, a payload with every planned
        output field set to ``None`` is buffered (when any are planned) and the
        failure is recorded on ``ctx``. Otherwise the exception propagates.

        Returns:
            The same ``ctx``, with the payload appended to
            ``ctx.transform_buffer[lindex]``.
        """
        logger.debug("transforms : %s %s", id(self.transforms), len(self.transforms))
        observation = self._extract_observation(nargs, **kwargs)
        missing = self._missing_declared_keys(observation)
        if missing:
            if self._fail_fast:
                raise TransformException(
                    f"Missing required input keys: {sorted(missing)}"
                )
            # Leave a trace instead of dropping the observation silently.
            logger.debug(
                "Skipping transform %s: missing input keys %s",
                self._transform_label(),
                sorted(missing),
            )
            return ctx
        try:
            transform_result = self.t(observation)
        except Exception as exc:
            if not self._tolerate_transform_errors:
                raise
            nulled_fields = self.t.planned_output_field_names(
                observation if isinstance(observation, dict) else None
            )
            if nulled_fields:
                payload = TransformPayload(named={k: None for k in nulled_fields})
                ctx.transform_buffer[lindex].append(payload)
                ctx.record_transform_observation(location=lindex, payload=payload)
            ctx.record_transform_failure(
                location=lindex,
                transform_label=self._transform_label(),
                exc=exc,
                traceback_text=self._format_traceback(exc),
                nulled_fields=nulled_fields,
            )
            return ctx
        if self._rename_map is not None:
            # Rename actors also report which source keys were renamed away.
            base = TransformPayload.from_result(transform_result)
            _update_doc = TransformPayload(
                named=base.named,
                positional=base.positional,
                removed_keys=self._rename_removed_keys(observation),
            )
        else:
            _update_doc = self._format_transform_result(transform_result)
        ctx.transform_buffer[lindex].append(_update_doc)
        ctx.record_transform_observation(location=lindex, payload=_update_doc)
        return ctx

    def references_vertices(self) -> set[str]:
        """Return the vertex names this actor references (always none)."""
        return set()