class TransformActor(Actor):
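    """Actor that applies a configured ``Transform`` to an observation.

    The transform is built from ``config.rename`` or from ``config.call``;
    each call appends the resulting ``TransformPayload`` to the extraction
    context's transform buffer for the given location.
    """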
def __init__(self, config: TransformActorConfig):
self.transforms: dict[str, ProtoTransform] = {}
self.call_use: str | None = None
self._call_config = None
self._fail_fast = False
self._tolerate_transform_errors = True
self._declared_input_keys: frozenset[str] = frozenset()
self._rename_map: dict[str, str] | None = None
if config.rename is not None:
self.t = Transform(rename=config.rename)
self._rename_map = dict(config.rename)
return
if config.call is None:
raise ValueError(
"TransformActorConfig requires call when rename is absent."
)
call = config.call
self._call_config = call
self.call_use = call.use
inline_target = (
call.target
if call.target is not None
else "values"
if call.use is None
else None
)
transform_kwargs: dict[str, Any] = {
"name": call.use,
"params": call.params,
"module": call.module,
"foo": call.foo,
"input": tuple(call.input) if call.input else (),
"output": tuple(call.output) if call.output else (),
"input_groups": (
tuple(tuple(group) for group in call.input_groups)
if call.input_groups
else ()
),
"output_groups": (
tuple(tuple(group) for group in call.output_groups)
if call.output_groups
else ()
),
"dress": call.dress,
"strategy": call.strategy or "single",
}
if inline_target is not None:
transform_kwargs["target"] = inline_target
if call.use is None:
if call.keys is not None:
transform_kwargs["keys"] = KeySelectionConfig.model_validate(
call.keys.model_dump()
)
# When call.use references ingestion_model.transforms, defer strict
# transform validation until finish_init can hydrate module/foo.
if call.use is not None and call.module is None and call.foo is None:
self.t = Transform(name=call.use)
return
self.t = Transform(**transform_kwargs)
def _refresh_missing_key_guard(self, init_ctx: ActorInitContext) -> None:
self._fail_fast = init_ctx.fail_fast
self._tolerate_transform_errors = init_ctx.tolerate_transform_errors
if self.t.target == "keys" or self.t.strategy == "all":
self._declared_input_keys = frozenset()
return
if self._rename_map is not None and not self._fail_fast:
self._declared_input_keys = frozenset()
return
required: set[str] = set(self.t.input)
for group in self.t.input_groups:
required.update(group)
self._declared_input_keys = frozenset(required)
@staticmethod
def _extract_observation(nargs: tuple[Any, ...], **kwargs: Any) -> Any:
"""Return the observation slice: dict row or scalar/list positional value."""
if kwargs:
observation: Any | None = kwargs.get("doc")
elif nargs:
observation = nargs[0]
else:
raise ValueError(f"{type(TransformActor).__name__}: doc should be provided")
if observation is None:
raise ValueError(f"{type(TransformActor).__name__}: doc should be provided")
return observation
@staticmethod
def _observation_keys(observation: Any) -> frozenset[str]:
if isinstance(observation, dict):
return frozenset(str(k) for k in observation.keys())
return frozenset()
def _missing_declared_keys(self, observation: Any) -> frozenset[str]:
if not self._declared_input_keys or not isinstance(observation, dict):
return frozenset()
return self._declared_input_keys - self._observation_keys(observation)
def _rename_removed_keys(self, observation: Any) -> frozenset[str]:
if self._rename_map is None or not isinstance(observation, dict):
return frozenset()
return frozenset(src for src in self._rename_map if src in observation)
def fetch_important_items(self) -> dict[str, Any]:
    """Return the ``transform`` items plus the transform's input/output."""
    summary = self._fetch_items_from_dict(("transform",))
    summary["t.input"] = self.t.input
    summary["t.output"] = self.t.output
    return summary
@classmethod
def from_config(cls, config: TransformActorConfig) -> TransformActor:
    """Alternate constructor: build an instance directly from ``config``."""
    actor = cls(config)
    return actor
def init_transforms(self, init_ctx: ActorInitContext) -> None:
    """Point ``self.transforms`` at the transform mapping from ``init_ctx``."""
    shared_transforms = init_ctx.transforms
    self.transforms = shared_transforms
def _merge_call_with_proto(self, call: Any, pt: ProtoTransform) -> dict[str, Any]:
next_params = call.params if call.params else pt.params
next_dress = call.dress if call.dress is not None else pt.dress
next_target = call.target if call.target is not None else pt.target
if next_target == "keys":
if call.input or call.output or call.input_groups or call.output_groups:
raise ValueError(
"call.input, call.output, call.input_groups, and call.output_groups "
"cannot be used when the effective transform target is keys "
"(from call.target or the named ingestion_model.transforms entry)."
)
if call.dress is not None:
raise ValueError("call.dress is not supported when target='keys'.")
if call.strategy is not None and call.strategy != "single":
raise ValueError(
"call.strategy is not allowed when target='keys'; "
"key mode uses implicit per-key execution."
)
next_input: tuple[str, ...] = ()
next_output: tuple[str, ...] = ()
next_input_groups: tuple[tuple[str, ...], ...] = ()
next_output_groups: tuple[tuple[str, ...], ...] = ()
else:
next_input_groups = (
tuple(tuple(group) for group in call.input_groups)
if call.input_groups
else pt.input_groups
)
next_output_groups = (
tuple(tuple(group) for group in call.output_groups)
if call.output_groups
else pt.output_groups
)
if next_input_groups:
next_input = ()
# Explicit grouped override should not inherit potentially
# conflicting proto output/output_groups for a different shape.
if call.input_groups:
next_output_groups = (
tuple(tuple(group) for group in call.output_groups)
if call.output_groups
else ()
)
next_output = tuple(call.output) if call.output else ()
elif next_dress is not None:
next_output = (next_dress.key, next_dress.value)
else:
next_output = tuple(call.output) if call.output else pt.output
else:
next_input = tuple(call.input) if call.input else pt.input
if next_dress is not None:
next_output = (next_dress.key, next_dress.value)
else:
next_output = tuple(call.output) if call.output else pt.output
transform_kwargs: dict[str, Any] = {
"dress": next_dress,
"name": call.use,
"module": pt.module,
"foo": pt.foo,
"params": next_params,
"input": next_input,
"output": next_output,
"input_groups": next_input_groups,
"output_groups": next_output_groups,
"strategy": call.strategy or "single",
"target": next_target,
}
if call.keys is not None:
transform_kwargs["keys"] = KeySelectionConfig.model_validate(
call.keys.model_dump()
)
else:
transform_kwargs["keys"] = pt.keys.model_copy(deep=True)
return transform_kwargs
def finish_init(self, init_ctx: ActorInitContext) -> None:
    """Resolve a deferred named transform, then refresh the key guard.

    If the actor references a named transform (``call_use``) whose callable
    has not been resolved yet, the matching proto transform is looked up in
    ``init_ctx.transforms`` and merged with the stored call configuration.

    Raises:
        ValueError: If the named transform is missing and
            ``init_ctx.strict_references`` is true.
    """
    self.transforms = init_ctx.transforms

    needs_hydration = (
        self.call_use is not None
        and self.t._foo is None
        and self._call_config is not None
    )
    if needs_hydration:
        proto = self.transforms.get(self.call_use)
        if proto is not None:
            merged = self._merge_call_with_proto(self._call_config, proto)
            self.t = Transform(**merged)
        elif init_ctx.strict_references:
            raise ValueError(
                f"Transform '{self.call_use}' referenced by transform.call.use "
                "was not found in ingestion_model.transforms."
            )

    self._refresh_missing_key_guard(init_ctx)
def _format_transform_result(self, result: Any) -> TransformPayload:
    """Wrap a raw transform result via ``TransformPayload.from_result``."""
    payload = TransformPayload.from_result(result)
    return payload
def _transform_label(self) -> str:
if self.call_use:
return self.call_use
if self.t.foo and self.t.module:
return f"{self.t.module}.{self.t.foo}"
if self.t.foo:
return self.t.foo
if self.t.name:
return self.t.name
return type(self.t).__name__
def _format_traceback(self, exc: BaseException) -> str:
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
def __call__(
    self, ctx: ExtractionContext, lindex: LocationIndex, *nargs: Any, **kwargs: Any
) -> ExtractionContext:
    """Apply the transform to one observation and buffer the result.

    Args:
        ctx: Extraction context that receives the payload and any
            failure record.
        lindex: Location under which the payload is buffered.
        *nargs: Positional arguments; the first is the observation when no
            keyword arguments are given.
        **kwargs: Keyword arguments; ``doc`` carries the observation.

    Returns:
        The same ``ctx`` object.

    Raises:
        TransformException: If declared input keys are missing and
            fail-fast is enabled.
        Exception: Whatever the transform raises, when transform errors
            are not tolerated.
    """
    logger.debug("transforms : %s %s", id(self.transforms), len(self.transforms))
    doc = self._extract_observation(nargs, **kwargs)

    absent = self._missing_declared_keys(doc)
    if absent and self._fail_fast:
        raise TransformException(
            f"Missing required input keys: {sorted(absent)}"
        )
    if absent:
        # Without fail-fast, an incomplete observation is skipped silently.
        return ctx

    try:
        outcome = self.t(doc)
    except Exception as exc:
        if not self._tolerate_transform_errors:
            raise
        # Tolerated failure: null out the planned output fields and record
        # the failure on the context.
        row = doc if isinstance(doc, dict) else None
        nulled = self.t.planned_output_field_names(row)
        if nulled:
            placeholder = TransformPayload(named=dict.fromkeys(nulled))
            ctx.transform_buffer[lindex].append(placeholder)
            ctx.record_transform_observation(location=lindex, payload=placeholder)
        ctx.record_transform_failure(
            location=lindex,
            transform_label=self._transform_label(),
            exc=exc,
            traceback_text=self._format_traceback(exc),
            nulled_fields=nulled,
        )
        return ctx

    if self._rename_map is None:
        payload = self._format_transform_result(outcome)
    else:
        # Rename mode also reports which source keys were consumed.
        parsed = TransformPayload.from_result(outcome)
        payload = TransformPayload(
            named=parsed.named,
            positional=parsed.positional,
            removed_keys=self._rename_removed_keys(doc),
        )

    ctx.transform_buffer[lindex].append(payload)
    ctx.record_transform_observation(location=lindex, payload=payload)
    return ctx
def references_vertices(self) -> set[str]:
    """Return the vertex names this actor references; always an empty set."""
    no_vertices: set[str] = set()
    return no_vertices