class DBConfig(BaseSettings, abc.ABC):
"""Abstract base class for all database connection configurations using Pydantic BaseSettings.

Declares the connection fields shared by every backend: ``uri``, the
credentials, ``database``, ``schema_name`` and ``request_timeout``.
Every field except ``request_timeout`` is optional and defaults to ``None``.
"""
# Backend address. Rewritten after validation by ``_normalize_uri``, which
# adds a default scheme and/or port when they are missing.
uri: str | None = Field(default=None, description="Backend URI")
# Optional credentials.
username: str | None = Field(default=None, description="Authentication username")
password: str | None = Field(default=None, description="Authentication Password")
# Database name as supplied by the user; how it is interpreted is decided by
# the subclass through ``_get_effective_database``.
database: str | None = Field(
default=None,
description="Database name (backward compatibility, DB-specific mapping)",
)
# Schema/graph name. The alias choices let input use either the key
# "schema" or the key "schema_name".
schema_name: str | None = Field(
default=None,
validation_alias=AliasChoices("schema", "schema_name"),
description="Schema/graph name (unified internal structure)",
)
# Timeout in seconds; defaults to 60.0.
request_timeout: float = Field(
default=60.0, description="Request timeout in seconds"
)
@abc.abstractmethod
def _get_default_port(self) -> int:
"""Get the default port for this db type.

Called by ``_normalize_uri`` when the configured URI carries no port.

Returns:
The port number to append to a URI that has none.
"""
pass
@abc.abstractmethod
def _get_effective_database(self) -> str | None:
"""Get the effective database name based on DB type.

Exposed publicly through the ``effective_database`` property.

Intended contract for implementations:
For SQL databases: returns the database name
For graph databases: returns None (they don't have a database level)

Returns:
Database name or None
"""
pass
@abc.abstractmethod
def _get_effective_schema(self) -> str | None:
"""Get the effective schema/graph name based on DB type.

Exposed publicly through the ``effective_schema`` property.

Intended contract for implementations:
For SQL databases: returns the schema name
For graph databases: returns the graph/database name (mapped from user-facing field)

Returns:
Schema/graph name or None
"""
pass
@property
def effective_database(self) -> str | None:
"""Get the effective database name (delegates to concrete class).

Returns:
Whatever the subclass's ``_get_effective_database()`` returns.
"""
return self._get_effective_database()
@property
def effective_schema(self) -> str | None:
"""Get the effective schema/graph name (delegates to concrete class).

Returns:
Whatever the subclass's ``_get_effective_schema()`` returns.
"""
return self._get_effective_schema()
@model_validator(mode="after")
def _normalize_uri(self):
    """Normalize ``uri`` by filling in a missing scheme and a missing port.

    Nothing happens when ``uri`` is ``None``. Otherwise:

    1. Scheme: a URI that uses neither a known database scheme nor an
       explicit ``scheme://authority`` form (for example ``localhost:14240``
       or ``db.example.com/path``) is prefixed with ``http://``. A URI that
       already carries its own ``scheme://`` is left untouched, even when
       the scheme is not in the known list.
    2. Port: when the URI has a scheme and a hostname but no port, the
       value of ``_get_default_port()`` is appended to the authority. User
       info (``user:pass@``), IPv6 brackets, path, params, query and
       fragment are preserved.

    Returns:
        This instance, with ``uri`` possibly rewritten.
    """
    if self.uri is None:
        return self

    # Valid URL schemes (common database protocols)
    valid_schemes = {
        "http",
        "https",
        "bolt",
        "bolt+s",
        "bolt+ssc",
        "neo4j",
        "neo4j+s",
        "neo4j+ssc",
        "mongodb",
        "postgresql",
        "postgres",
        "mysql",
        "nebula",
        "redis",  # FalkorDB uses redis:// protocol
        "rediss",  # Redis with SSL
    }
    default_scheme = "http"  # Default to http for most DBs

    parsed = urlparse(self.uri)
    has_valid_scheme = parsed.scheme.lower() in valid_schemes
    has_netloc = bool(parsed.netloc)
    # urlparse reads "localhost:14240" as scheme="localhost" with an empty
    # netloc, so a scheme only counts as explicit when an authority follows.
    has_explicit_scheme = bool(parsed.scheme) and has_netloc

    if not has_valid_scheme and not has_explicit_scheme:
        looks_like_host_port = ":" in self.uri and not self.uri.startswith("//")
        if looks_like_host_port or not has_netloc:
            # Bare "host", "host:port" or "host/path": the whole string is
            # the authority (plus path), so prefixing the scheme suffices.
            self.uri = f"{default_scheme}://{self.uri}"
            parsed = urlparse(self.uri)

    # Add default port if missing
    if parsed.port is None:
        default_port = self._get_default_port()
        if default_port and parsed.scheme and parsed.hostname:
            # Work on netloc rather than hostname so that credentials and
            # IPv6 brackets survive; a dangling ":" (empty port) is dropped.
            authority = parsed.netloc.rstrip(":")
            self.uri = parsed._replace(
                netloc=f"{authority}:{default_port}"
            ).geturl()

    return self
@model_validator(mode="after")
def _extract_port_from_uri(self):
    """Populate ``gs_port`` from the URI port when it was not given explicitly.

    Only configs that expose a ``gs_port`` attribute are affected (per the
    original notes this is the TigerGraph config, where ``gs_port`` is the
    primary port for TigerGraph 4+). Users can then supply just a URI with
    a port instead of setting ``gs_port`` separately.

    Returns:
        This instance, with ``gs_port`` possibly set.
    """
    # Configs without a gs_port attribute are left alone.
    if not hasattr(self, "gs_port"):
        return self

    # Nothing to derive without a URI; never override an explicit gs_port.
    if not self.uri or self.gs_port is not None:
        return self

    port_text = self.port  # port parsed out of the URI, as a string
    if not port_text:
        return self

    try:
        self.gs_port = int(port_text)
        logger.debug(
            f"Automatically set gs_port={self.gs_port} from URI port"
        )
    except (ValueError, TypeError):
        # Best effort: leave gs_port unset if the value cannot be applied.
        pass

    return self
@model_validator(mode="after")
def _check_port_conflicts(self):
    """Reconcile separate port fields with the port embedded in the URI.

    Looks at the optional ``gs_port`` and ``bolt_port`` attributes. Any of
    them that is set to a value different from the URI port triggers a
    ``UserWarning`` (also logged) and is then overwritten with the URI
    port, so the URI is always the source of truth.

    Returns:
        This instance, with conflicting port fields aligned to the URI.
    """
    if self.uri is None:
        return self

    uri_port = self.port  # port parsed out of the URI, as a string
    if uri_port is None:
        return self

    # Port attributes that subclasses may define; a missing attribute reads as None.
    candidate_fields = ("gs_port", "bolt_port")
    configured = [(name, getattr(self, name, None)) for name in candidate_fields]

    # Compare as strings so that an int field and the str URI port match.
    mismatched = [
        (name, value)
        for name, value in configured
        if value is not None and str(value) != str(uri_port)
    ]
    if not mismatched:
        return self

    details = [
        f"{name}={value} (URI has port={uri_port})" for name, value in mismatched
    ]
    warning_msg = (
        f"Port conflict detected: Port specified both in URI ({uri_port}) "
        f"and as separate field(s): {', '.join(details)}. "
        f"Using port from URI ({uri_port}). Consider removing the separate port field(s)."
    )
    warnings.warn(warning_msg, UserWarning, stacklevel=2)
    logger.warning(warning_msg)

    # The URI wins: overwrite each conflicting field with its port.
    for name, _ in mismatched:
        try:
            setattr(self, name, int(uri_port))
        except (ValueError, AttributeError):
            # Field might be read-only or not settable, that's okay
            pass

    return self
@property
def url(self) -> str | None:
"""Backward compatibility property: alias for uri.

Returns:
The current value of ``uri`` (``None`` when no URI is configured).
"""
return self.uri
@property
def url_without_port(self) -> str:
    """Get the URL reduced to ``scheme://host`` (no port, path or credentials).

    Returns:
        ``"<scheme>://<hostname>"`` built from ``uri``. IPv6 literals keep
        their square brackets so the result is still a valid URL prefix.

    Raises:
        ValueError: If ``uri`` is not set.
    """
    if self.uri is None:
        raise ValueError("URI is not set")
    parsed = urlparse(self.uri)
    host = parsed.hostname
    # urlparse strips the brackets from IPv6 literals ("[::1]" -> "::1");
    # put them back, otherwise the result cannot be extended with ":port".
    if host is not None and ":" in host:
        host = f"[{host}]"
    # NOTE(review): a URI without a hostname still yields "<scheme>://None",
    # as before; callers appear to rely on uri being normalized first.
    return f"{parsed.scheme}://{host}"
@property
def port(self) -> str | None:
"""Get port from URI."""
if self.uri is None:
return None
parsed = urlparse(self.uri)
return str(parsed.port) if parsed.port else None
@property
def protocol(self) -> str:
    """Return the scheme of ``uri``, falling back to ``"http"``.

    Returns:
        The URI scheme, or ``"http"`` when ``uri`` is unset or has no scheme.
    """
    fallback = "http"
    if self.uri is None:
        return fallback
    scheme = urlparse(self.uri).scheme
    return scheme if scheme else fallback
@property
def hostname(self) -> str | None:
"""Get hostname from URI."""
if self.uri is None:
return None
parsed = urlparse(self.uri)
return parsed.hostname
@property
def connection_type(self) -> "DBType":
    """Get the database type registered for this config's class.

    The lookup walks the class's MRO, so an exact match in
    ``DB_TYPE_MAPPING`` wins and otherwise the nearest registered ancestor
    is used. This keeps subclasses of a registered config (including
    dynamically created ones) mapped to the right database type instead of
    silently falling through to the fallback.

    Returns:
        The matching ``DBType``; ``DBType.ARANGO`` if no class in the MRO
        is registered.
    """
    # Imported here to avoid a circular import with config_mapping.
    from .config_mapping import DB_TYPE_MAPPING

    # type(self) comes first in the MRO, so exact matches keep priority.
    for klass in type(self).__mro__:
        for db_type, config_class in DB_TYPE_MAPPING.items():
            if klass is config_class:
                return db_type
    # Fallback (shouldn't happen)
    return DBType.ARANGO
def can_be_source(self) -> bool:
"""Check if this database type can be used as a source.

Returns:
True when ``connection_type`` is a member of ``SOURCE_DATABASES``.
"""
return self.connection_type in SOURCE_DATABASES
def can_be_target(self) -> bool:
"""Check if this database type can be used as a target.

Returns:
True when ``connection_type`` is a member of ``TARGET_DATABASES``.
"""
return self.connection_type in TARGET_DATABASES
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DBConfig":
    """Create a connection config from a dictionary.

    The database type is read from ``db_type`` (preferred) or
    ``connection_type`` and selects the concrete config class. Legacy keys
    are translated before construction: ``url`` -> ``uri``, ``cred_name``
    -> ``username``, ``cred_pass`` -> ``password``. When no URI is given,
    one is taken from ``hosts`` or assembled from ``protocol`` /
    ``hostname`` / ``port``. The input dict is not modified.

    Args:
        data: Configuration mapping; must contain ``db_type`` or
            ``connection_type``.

    Returns:
        An instance of the config class registered for the database type.

    Raises:
        TypeError: If ``data`` is not a dict.
        ValueError: If the database type is missing or not supported.
    """
    if not isinstance(data, dict):
        raise TypeError(f"Expected dict, got {type(data)}")

    # Copy the data to avoid modifying the original
    config_data = data.copy()

    # Remove both discriminator keys unconditionally; otherwise a leftover
    # "connection_type" would be forwarded to the config constructor.
    explicit_db_type = config_data.pop("db_type", None)
    legacy_db_type = config_data.pop("connection_type", None)
    db_type = explicit_db_type or legacy_db_type
    if not db_type:
        raise ValueError("Missing 'db_type' or 'connection_type' in configuration")

    try:
        conn_type = DBType(db_type)
    except ValueError as err:
        raise ValueError(
            f"Database type '{db_type}' not supported. "
            f"Should be one of: {list(DBType)}"
        ) from err

    # Map old 'url' field to 'uri' for backward compatibility
    if "url" in config_data and "uri" not in config_data:
        config_data["uri"] = config_data.pop("url")

    # Map old credential fields
    if "cred_name" in config_data and "username" not in config_data:
        config_data["username"] = config_data.pop("cred_name")
    if "cred_pass" in config_data and "password" not in config_data:
        config_data["password"] = config_data.pop("cred_pass")

    # Construct URI from protocol/hostname/port if uri is not provided
    if "uri" not in config_data:
        protocol = config_data.pop("protocol", "http")
        hostname = config_data.pop("hostname", None)
        port = config_data.pop("port", None)
        hosts = config_data.pop("hosts", None)

        if hosts:
            # Use hosts as URI
            config_data["uri"] = hosts
        elif hostname:
            # Construct URI from components
            if port:
                config_data["uri"] = f"{protocol}://{hostname}:{port}"
            else:
                config_data["uri"] = f"{protocol}://{hostname}"

    # Imported here to avoid a circular import with config_mapping.
    from .config_mapping import get_config_class

    config_class = get_config_class(conn_type)
    return config_class(**config_data)
@classmethod
def from_docker_env(cls, docker_dir: str | Path | None = None) -> "DBConfig":
"""Load config from docker .env file.

This base implementation only raises; subclasses are expected to override it.

Args:
docker_dir: Path to docker directory. If None, uses default based on db type
(expected behavior of subclass implementations).

Returns:
DBConfig instance loaded from .env file

Raises:
NotImplementedError: Always, in this base implementation.
"""
raise NotImplementedError("Subclasses must implement from_docker_env")
@classmethod
def from_env(
    cls: Type[T],
    *,
    prefix: str | None = None,
    profile: str | None = None,
    suffix: str | None = None,
) -> T:
    """Load config from environment variables using Pydantic BaseSettings.

    Supports qualifiers for multiple configs from the same env:

    - **prefix**: outer prefix → ``{prefix}_{BASE_PREFIX}URI`` (e.g. ``USER_ARANGO_URI``).
    - **profile**: segment after base → ``{BASE_PREFIX}{profile}_URI`` (e.g. ``ARANGO_DEV_URI``).
    - **suffix**: after field name → ``{BASE_PREFIX}URI_{suffix}`` (e.g. ``ARANGO_URI_DEV``).

    At most one of ``prefix``, ``profile``, ``suffix`` should be set.

    The returned object is always an instance of exactly ``cls``: the
    qualified prefix is passed to the settings constructor instead of
    building a throwaway subclass on every call.

    Args:
        prefix: Outer env prefix (e.g. ``"USER"`` → ``USER_ARANGO_URI``).
        profile: Env segment after base (e.g. ``"DEV"`` → ``ARANGO_DEV_URI``).
        suffix: Env segment after field name (e.g. ``"DEV"`` → ``ARANGO_URI_DEV``).

    Returns:
        DBConfig instance loaded from environment variables.

    Raises:
        ValueError: If ``cls`` has no ``env_prefix`` in ``model_config``, or
            if more than one qualifier is given.

    Examples:
        # Default (ARANGO_URI, ARANGO_USERNAME, ...)
        config = ArangoConfig.from_env()

        # By profile: ARANGO_DEV_URI, ARANGO_DEV_USERNAME, ...
        dev = ArangoConfig.from_env(profile="DEV")

        # By suffix: ARANGO_URI_DEV, ARANGO_USERNAME_DEV, ...
        dev2 = ArangoConfig.from_env(suffix="DEV")

        # Outer prefix: USER_ARANGO_URI, ...
        user_config = ArangoConfig.from_env(prefix="USER")
    """
    base_prefix = cls.model_config.get("env_prefix")
    if not base_prefix:
        raise ValueError(
            f"Class {cls.__name__} does not have env_prefix configured in model_config"
        )
    case_sensitive = cls.model_config.get("case_sensitive", False)

    qualifiers = sum(1 for q in (prefix, profile, suffix) if q is not None)
    if qualifiers > 1:
        raise ValueError("At most one of prefix, profile, suffix may be set")

    if suffix:
        # Pydantic doesn't support env_suffix; read suffixed vars manually
        # and pass them as init kwargs (fields without a suffixed variable
        # are left to the normal settings sources).
        data: Dict[str, Any] = {}
        suf = suffix if case_sensitive else suffix.upper()
        for name in cls.model_fields:
            env_name = f"{base_prefix}{name.upper()}_{suf}"
            val = os.environ.get(env_name)
            if not val and not case_sensitive:
                # Case-insensitive mode: try the exact name first, then
                # fall back to the all-lowercase spelling.
                val = os.environ.get(env_name.lower())
            if val is not None:
                data[name] = val
        return cls(**data)

    if prefix:
        new_prefix = f"{prefix.upper()}_{base_prefix}"
    elif profile:
        new_prefix = f"{base_prefix}{profile.upper()}_"
    else:
        return cls()

    # BaseSettings accepts per-instance overrides of its settings config as
    # underscore-prefixed init arguments, so no dynamic subclass is needed.
    return cls(_env_prefix=new_prefix, _case_sensitive=case_sensitive)