class SchemaSanitizer:
"""Sanitizes schema attributes to avoid reserved words and normalize indexes.
This class handles:
- Sanitizing vertex names and field names to avoid reserved words
- Normalizing vertex indexes for TigerGraph (ensuring consistent indexes
for edges with the same relation)
- Applying field index mappings to resources
"""
def __init__(self, db_flavor: DBType):
    """Create a sanitizer bound to a single target database flavor.

    Args:
        db_flavor: Database flavor; it is stored on the instance and passed
            to ``load_reserved_words`` to obtain the reserved-word collection
            consulted by :meth:`sanitize`.
    """
    # Rename bookkeeping starts out empty and is filled in by ``sanitize``.
    # storage name before sanitizing -> storage name after sanitizing
    self.vertex_mappings: dict[str, str] = {}
    # vertex name -> {original field name: sanitized field name}
    self.vertex_attribute_mappings: defaultdict[str, dict[str, str]] = (
        defaultdict(dict)
    )

    self.db_flavor = db_flavor
    self.reserved_words = load_reserved_words(db_flavor)
def sanitize(
    self,
    schema: Schema,
    ingestion_model: IngestionModel | None = None,
) -> Schema:
    """Sanitize names in ``schema`` and normalize TigerGraph vertex identities.

    The schema is modified in place and the same object is returned. The
    steps run in this order:

    1. Vertex storage names held by ``schema.db_profile`` are passed through
       ``sanitize_attribute_name`` (suffix ``_{VERTEX_SUFFIX}``); changed
       names are written back to ``db_profile.vertex_storage_names`` and
       recorded in ``self.vertex_mappings``.
    2. Vertex field names are passed through ``sanitize_attribute_name``;
       changed names are recorded in ``self.vertex_attribute_mappings`` and
       each vertex identity is rewritten to use the new field names.
    3. Edge relation names are passed through ``sanitize_attribute_name``
       (suffix ``_{RELATION_SUFFIX}``) and additionally made distinct from
       every vertex storage name; changed names are stored with
       ``db_profile.set_edge_name_spec``.
    4. Only when ``schema.db_profile.db_flavor`` is ``DBType.TIGERGRAPH``:
       for edges that share a relation name, the source vertices (and,
       separately, the target vertices) are given a common identity.
    5. If step 4 produced field renames and ``ingestion_model`` is given,
       the renames are applied to every resource of the ingestion model.

    Args:
        schema: The schema to sanitize; mutated in place.
        ingestion_model: Optional ingestion model whose resources receive
            the identity-field renames produced by step 4. When ``None``,
            step 5 is skipped.

    Returns:
        The same ``schema`` object. When ``self.reserved_words`` is empty the
        schema is returned immediately and none of the steps above run.
    """
    if not self.reserved_words:
        # No reserved words to check, return schema as-is
        return schema

    self._sanitize_vertex_storage_names(schema)
    self._sanitize_vertex_field_names(schema)
    self._sanitize_edge_relation_names(schema)

    # vertex_name -> {old_field: new_field}; stays empty unless TigerGraph
    field_index_mappings = self._normalize_tigergraph_edge_indexes(schema)

    # The field renames from identity normalization must also reach the
    # ingestion side, otherwise resources would keep emitting the old names.
    if field_index_mappings and ingestion_model is not None:
        for resource in ingestion_model.resources:
            self._apply_field_index_mappings_to_resource(
                resource, field_index_mappings
            )

    return schema

def _sanitize_vertex_storage_names(self, schema: Schema) -> None:
    """Rewrite vertex storage names changed by ``sanitize_attribute_name``."""
    profile = schema.db_profile
    for vertex in schema.core_schema.vertex_config.vertices:
        dbname = profile.vertex_storage_name(vertex.name)
        sanitized_vertex_name = sanitize_attribute_name(
            dbname, self.reserved_words, suffix=f"_{VERTEX_SUFFIX}"
        )
        if sanitized_vertex_name == dbname:
            continue
        logger.debug(
            f"Sanitizing vertex name '{dbname}' -> '{sanitized_vertex_name}'"
        )
        self.vertex_mappings[dbname] = sanitized_vertex_name
        # The storage name is keyed by the logical vertex name, which is
        # itself left unchanged.
        profile.vertex_storage_names[vertex.name] = sanitized_vertex_name

def _sanitize_vertex_field_names(self, schema: Schema) -> None:
    """Rename vertex fields changed by ``sanitize_attribute_name`` and fix identities."""
    for vertex in schema.core_schema.vertex_config.vertices:
        for field in vertex.fields:
            original_name = field.name
            sanitized_name = sanitize_attribute_name(
                original_name, self.reserved_words
            )
            if sanitized_name == original_name:
                continue
            self.vertex_attribute_mappings[vertex.name][original_name] = (
                sanitized_name
            )
            logger.debug(
                f"Sanitizing field name '{original_name}' -> '{sanitized_name}' "
                f"in vertex '{vertex.name}'"
            )
            field.name = sanitized_name

        # Identity entries refer to fields by name, so they must follow the
        # renames recorded above; unrenamed entries are kept as they are.
        vertex.identity = [
            self.vertex_attribute_mappings[vertex.name].get(item, item)
            for item in vertex.identity
        ]

def _sanitize_edge_relation_names(self, schema: Schema) -> None:
    """Sanitize edge relation names and keep them distinct from vertex storage names."""
    profile = schema.db_profile
    vertex_names = {
        profile.vertex_storage_name(vertex.name)
        for vertex in schema.core_schema.vertex_config.vertices
    }

    for edge in schema.core_schema.edge_config.edges:
        if not edge.relation:
            continue
        original = profile.edge_relation_name(
            edge.edge_id,
            default_relation=edge.relation,
        )
        if original is None:
            continue

        # Step A: sanitize against reserved words
        sanitized = sanitize_attribute_name(
            original,
            self.reserved_words,
            suffix=f"_{RELATION_SUFFIX}",
        )

        # Step B: avoid collision with vertex storage names by appending the
        # relation suffix, then a running counter until the name is free
        if sanitized in vertex_names:
            base = f"{sanitized}_{RELATION_SUFFIX}"
            candidate = base
            counter = 1
            while candidate in vertex_names:
                candidate = f"{base}_{counter}"
                counter += 1
            sanitized = candidate

        # Only touch the profile when the name actually changed
        if sanitized != original:
            profile.set_edge_name_spec(
                edge.edge_id,
                relation_name=sanitized,
            )

def _normalize_tigergraph_edge_indexes(
    self, schema: Schema
) -> dict[str, dict[str, str]]:
    """Give edges that share a relation consistent source/target identities.

    Does nothing unless ``schema.db_profile.db_flavor`` is
    ``DBType.TIGERGRAPH``. Edges are grouped by their (sanitized) relation
    name; for every group with more than one edge the identities of all
    source vertices, and separately of all target vertices, are aligned via
    ``_normalize_vertex_indexes``.

    Returns:
        Mapping ``vertex_name -> {old_field: new_field}`` describing the
        identity fields that were renamed; empty when nothing changed.
    """
    field_index_mappings: dict[str, dict[str, str]] = {}
    if not (schema.db_profile.db_flavor == DBType.TIGERGRAPH):
        return field_index_mappings

    # Group edges by relation, using the sanitized name where one is set
    edges_by_relation: dict[str | None, list[Edge]] = {}
    for edge in schema.core_schema.edge_config.edges:
        relation = (
            schema.db_profile.edge_relation_name(
                edge.edge_id,
                default_relation=edge.relation,
            )
            or edge.relation
        )
        edges_by_relation.setdefault(relation, []).append(edge)

    vertex_config = schema.core_schema.vertex_config
    for relation, relation_edges in edges_by_relation.items():
        if len(relation_edges) <= 1:
            # Only one edge with this relation, no normalization needed
            continue

        # Lists (not dicts) so that a vertex occurring in several edges of
        # the relation is captured every time it occurs. Both lists are
        # built before either role is normalized.
        source_vertex_indexes: list[tuple[str, tuple[str, ...]]] = []
        target_vertex_indexes: list[tuple[str, tuple[str, ...]]] = []
        for edge in relation_edges:
            source_vertex_indexes.append(
                (edge.source, tuple(vertex_config.identity_fields(edge.source)))
            )
            target_vertex_indexes.append(
                (edge.target, tuple(vertex_config.identity_fields(edge.target)))
            )

        self._normalize_vertex_indexes(
            source_vertex_indexes,
            relation,
            schema,
            field_index_mappings,
            "source",
        )
        self._normalize_vertex_indexes(
            target_vertex_indexes,
            relation,
            schema,
            field_index_mappings,
            "target",
        )

    return field_index_mappings
def _normalize_vertex_indexes(
    self,
    vertex_indexes: list[tuple[str, tuple[str, ...]]],
    relation: str | None,
    schema: Schema,
    field_index_mappings: dict[str, dict[str, str]],
    role: str,  # "source" or "target" for logging
) -> None:
    """Align a group of vertices on the index used by most of them.

    Each vertex is counted once (its first occurrence in ``vertex_indexes``).
    If the vertices do not all share one index, the most frequent index wins
    and every other vertex is changed in place:

    - old -> new field renames are recorded in ``field_index_mappings``
      (positionally when both indexes have the same length, otherwise only
      the first field is mapped);
    - fields of the winning index missing on the vertex are appended;
    - fields of the old index that are not part of the winning index are
      removed;
    - the vertex identity is replaced by the winning index.

    Args:
        vertex_indexes: ``(vertex_name, index_fields)`` pairs; a vertex may
            occur more than once.
        relation: Relation name, used only in the log message.
        schema: Schema whose vertices are updated.
        field_index_mappings: ``vertex_name -> {old_field: new_field}``;
            updated in place.
        role: ``"source"`` or ``"target"``, used only in the log message.
    """
    if not vertex_indexes:
        return

    # First occurrence wins; later duplicates of a vertex are ignored.
    index_by_vertex: dict[str, tuple[str, ...]] = {}
    for name, fields in vertex_indexes:
        index_by_vertex.setdefault(name, fields)

    if len(set(index_by_vertex.values())) == 1:
        # Every vertex already uses the same index.
        return

    [(winning_index, _occurrences)] = Counter(
        index_by_vertex.values()
    ).most_common(1)
    new_fields = list(winning_index)

    for vertex_name, current_index in index_by_vertex.items():
        if current_index == winning_index:
            continue

        old_fields = list(current_index)
        renames = field_index_mappings.setdefault(vertex_name, {})

        if len(old_fields) == len(new_fields):
            # Same arity: pair the fields up by position.
            renames.update(
                (old_field, new_field)
                for old_field, new_field in zip(old_fields, new_fields)
                if old_field != new_field
            )
        elif old_fields and new_fields and old_fields[0] != new_fields[0]:
            # Different arity: only the leading field can be paired.
            renames[old_fields[0]] = new_fields[0]

        vertex = schema.core_schema.vertex_config[vertex_name]

        # Make sure every field of the winning index exists on the vertex.
        known_names = {f.name for f in vertex.fields}
        for wanted in winning_index:
            if wanted in known_names:
                continue
            vertex.fields.append(Field(name=wanted, type=None))
            known_names.add(wanted)

        # Drop the fields that belonged to the old index only.
        obsolete = [
            f
            for f in vertex.fields
            if f.name in old_fields and f.name not in new_fields
        ]
        for stale_field in obsolete:
            vertex.fields.remove(stale_field)

        # The logical identity now follows the winning index.
        vertex.identity = list(winning_index)

        logger.debug(
            f"Normalizing {role} index for vertex '{vertex_name}' in relation '{relation}': "
            f"{old_fields} -> {new_fields}"
        )
def _apply_field_index_mappings_to_resource(
    self, resource: Resource, field_index_mappings: dict[str, dict[str, str]]
) -> None:
    """Propagate identity-field renames to the actors of one resource.

    The actor tree is walked starting at ``resource.root``:

    - ``VertexActor`` with a non-empty ``from_doc``: every ``from_doc`` value
      that is a renamed field of that vertex is replaced by the new name.
    - ``TransformActor``: for every vertex visible at that point that has
      renames, the renames are merged into ``transform.rename`` and the
      transform IO is re-initialised from the map.
    - ``DescendActor``: each direct descendant is processed recursively. The
      vertices created by the direct descendants are visible to all of them,
      and vertices found deeper down become visible to later siblings.

    Visible vertices are applied to a ``TransformActor`` in sorted name
    order. The merge is order-dependent when several vertices carry renames,
    and plain ``set`` iteration order over strings can differ between
    interpreter runs, so sorting keeps the result reproducible.

    Args:
        resource: The resource to update (in place).
        field_index_mappings: ``vertex_name -> {old_field: new_field}``.
    """
    # Function-scope import. NOTE(review): presumably avoids a circular
    # import with the actor module - confirm before moving to file level.
    from graflo.architecture.pipeline.runtime.actor import (
        ActorWrapper,
        DescendActor,
        TransformActor,
        VertexActor,
    )

    def collect_vertices_at_level(wrappers: list[ActorWrapper]) -> set[str]:
        """Return names of vertices created directly by ``wrappers`` (no recursion)."""
        return {
            wrapper.actor.name
            for wrapper in wrappers
            if isinstance(wrapper.actor, VertexActor)
        }

    def apply_mappings_to_transform(
        mappings: dict[str, str],
        vertex_name: str,
        actor: TransformActor,
    ) -> None:
        """Merge ``mappings`` (old field -> new field) into ``actor.t.rename``.

        Args:
            mappings: Renames of one vertex.
            vertex_name: Vertex the renames belong to (for logging only).
            actor: The TransformActor to update.
        """
        transform = actor.t
        if transform.rename:
            # Redirect existing entries whose target is a renamed field.
            # Iterate over a snapshot because the dict is written to.
            for map_key, map_value in list(transform.rename.items()):
                if isinstance(map_value, str) and map_value in mappings:
                    transform.rename[map_key] = mappings[map_value]
            # Add a rename only if nothing in the map yields its target yet.
            for k, v in mappings.items():
                if v not in transform.rename.values():
                    transform.rename[k] = v
        else:
            # No map yet: start from a copy so the shared mapping dict is
            # never aliased into the transform.
            transform.rename = mappings.copy()

        # Update Transform object IO to reflect map edits
        actor.t._init_io_from_map(force_init=True)

        logger.debug(
            f"Updated TransformActor map in resource '{resource.name}' "
            f"for vertex '{vertex_name}': {mappings}"
        )

    def update_transform_actor_maps(
        wrapper: ActorWrapper, parent_vertices: set[str] | None = None
    ) -> set[str]:
        """Apply the renames to ``wrapper`` and, recursively, its descendants.

        Args:
            wrapper: ActorWrapper to process.
            parent_vertices: Vertices visible from enclosing levels.

        Returns:
            All vertices visible at this level (own, parent and nested).
        """
        if parent_vertices is None:
            parent_vertices = set()

        actor = wrapper.actor
        is_vertex_actor = isinstance(actor, VertexActor)

        # A new set is built here, so the caller's set is never mutated.
        current_level_vertices = {actor.name} if is_vertex_actor else set()
        all_available_vertices = current_level_vertices | parent_vertices

        # VertexActor with from_doc: rewrite doc_field values
        if is_vertex_actor and actor.from_doc:
            vertex_name = actor.name
            mappings = field_index_mappings.get(vertex_name)
            if mappings:
                from_doc = actor.from_doc
                changed = False
                for v_f, d_f in list(from_doc.items()):
                    if d_f in mappings:
                        from_doc[v_f] = mappings[d_f]
                        changed = True
                # Report an update only when an entry was actually rewritten.
                if changed:
                    logger.debug(
                        f"Updated VertexActor from_doc in resource '{resource.name}' "
                        f"for vertex '{vertex_name}': {mappings}"
                    )

        # TransformActor: apply renames of every visible vertex
        if isinstance(actor, TransformActor):
            applied_any = False
            # Sorted for a reproducible application order (see docstring).
            for vertex in sorted(all_available_vertices):
                mappings = field_index_mappings.get(vertex)
                if mappings:
                    apply_mappings_to_transform(mappings, vertex, actor)
                    applied_any = True

            if not applied_any:
                logger.debug(
                    f"Skipping TransformActor in resource '{resource.name}': "
                    f"no mappings for vertices {all_available_vertices}"
                )

        # DescendActor: recurse into nested structures
        if isinstance(actor, DescendActor):
            # Siblings see each other's vertices regardless of their order
            all_available_vertices |= collect_vertices_at_level(actor.descendants)

            for descendant_wrapper in actor.descendants:
                nested_vertices = update_transform_actor_maps(
                    descendant_wrapper, parent_vertices=all_available_vertices
                )
                # Vertices found deeper down become visible to later siblings
                all_available_vertices |= nested_vertices

        return all_available_vertices

    # Process the root ActorWrapper if it exists
    root = getattr(resource, "root", None)
    if root is None:
        logger.warning(
            f"Resource '{resource.name}' does not have a root ActorWrapper. "
            f"Skipping field index mapping updates."
        )
        return
    update_transform_actor_maps(root)