Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/dbslice/adapters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ def __init__(
profiler: Any = None,
schema: str | None = None,
allow_unsafe_where: bool = False,
statement_timeout_ms: int = 0,
):
self._conn: Any = None
self._schema_name = schema or "public"
self._schema_cache: SchemaGraph | None = None
self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE
self.profiler = profiler
self.allow_unsafe_where = allow_unsafe_where
self.statement_timeout_ms = statement_timeout_ms

def connect(self, url: str) -> None:
"""Establish PostgreSQL connection."""
Expand Down Expand Up @@ -62,6 +64,16 @@ def connect(self, url: str) -> None:
# Use autocommit for reads by default
self._conn.autocommit = True

# Set statement_timeout to prevent runaway queries
if self.statement_timeout_ms > 0:
with self._conn.cursor() as cur:
cur.execute(
"SET statement_timeout = %s", (self.statement_timeout_ms,)
)
logger.debug(
"statement_timeout set", timeout_ms=self.statement_timeout_ms
)

# Set search_path so unqualified table names resolve to the target schema
if self._schema_name != "public":
with self._conn.cursor() as cur:
Expand Down Expand Up @@ -628,13 +640,17 @@ def fetch_referencing_pks(
cur.execute(query, params)
rows = cur.fetchall()
for row in rows:
result.add(row)
# Filter out NULL values (nullable FKs)
if None not in row:
result.add(row)
tracker.record_rows(len(rows))
else:
with self._conn.cursor() as cur:
cur.execute(query, params)
for row in cur.fetchall():
result.add(row)
# Filter out NULL values (nullable FKs)
if None not in row:
result.add(row)

logger.debug(
"Fetched referencing PKs with batching",
Expand Down
36 changes: 36 additions & 0 deletions src/dbslice/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from dbslice import __version__
from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection
from dbslice.constants import (
DEFAULT_MAX_SEED_ROWS,
DEFAULT_OUTPUT_FILE_MODE,
DEFAULT_STATEMENT_TIMEOUT_MS,
DEFAULT_STREAMING_CHUNK_SIZE,
DEFAULT_STREAMING_THRESHOLD,
DEFAULT_TRAVERSAL_DEPTH,
Expand Down Expand Up @@ -276,6 +278,8 @@ def _build_extract_config(
output_file_mode: int,
schema: str | None = None,
allow_unsafe_where: bool = False,
max_seed_rows: int = DEFAULT_MAX_SEED_ROWS,
statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS,
) -> ExtractConfig:
"""
Build ExtractConfig from validated CLI parameters.
Expand All @@ -301,6 +305,7 @@ def _build_extract_config(
stream_threshold: Auto-enable streaming above this row count
stream_chunk_size: Streaming fetch chunk size
output_file_mode: Permissions mode for output files
max_seed_rows: Maximum rows a single seed query may return

Returns:
Configured ExtractConfig object
Expand Down Expand Up @@ -328,6 +333,8 @@ def _build_extract_config(
output_file_mode=output_file_mode,
schema=schema,
allow_unsafe_where=bool(allow_unsafe_where),
max_seed_rows=max_seed_rows,
statement_timeout_ms=statement_timeout_ms,
)


Expand Down Expand Up @@ -1095,6 +1102,20 @@ def extract(
help="Number of rows to fetch per chunk in streaming mode (default: 1000)",
),
] = None,
max_seed_rows: Annotated[
int | None,
typer.Option(
"--max-seed-rows",
help=f"Maximum rows a single seed query may return (default: {DEFAULT_MAX_SEED_ROWS})",
),
] = None,
statement_timeout: Annotated[
int | None,
typer.Option(
"--statement-timeout",
help="PostgreSQL statement timeout in milliseconds (0 = no timeout, default: 0)",
),
] = None,
output_file_mode: Annotated[
str | None,
typer.Option(
Expand Down Expand Up @@ -1257,6 +1278,12 @@ def extract(
effective_stream_chunk_size = (
stream_chunk_size if stream_chunk_size is not None else DEFAULT_STREAMING_CHUNK_SIZE
)
effective_max_seed_rows = (
max_seed_rows if max_seed_rows is not None else DEFAULT_MAX_SEED_ROWS
)
effective_statement_timeout = (
statement_timeout if statement_timeout is not None else DEFAULT_STATEMENT_TIMEOUT_MS
)
effective_output_file_mode = DEFAULT_OUTPUT_FILE_MODE

loaded_config = None
Expand Down Expand Up @@ -1377,6 +1404,10 @@ def extract(
raise ValueError("--stream-threshold must be greater than 0")
if effective_stream_chunk_size <= 0:
raise ValueError("--stream-chunk-size must be greater than 0")
if effective_max_seed_rows <= 0:
raise ValueError("--max-seed-rows must be greater than 0")
if effective_statement_timeout < 0:
raise ValueError("--statement-timeout must be >= 0")
except ValidationError as e:
console.print(f"[red]Validation Error:[/red] {e}")
raise typer.Exit(1)
Expand Down Expand Up @@ -1454,8 +1485,13 @@ def extract(
output_file_mode=effective_output_file_mode,
schema=schema,
allow_unsafe_where=effective_allow_unsafe_where,
max_seed_rows=effective_max_seed_rows,
)

# Apply CLI overrides to config
extract_config.max_seed_rows = effective_max_seed_rows
extract_config.statement_timeout_ms = effective_statement_timeout

# Apply compliance settings to extract config
extract_config.compliance_profiles = effective_compliance
extract_config.compliance_strict = effective_compliance_strict
Expand Down
67 changes: 67 additions & 0 deletions src/dbslice/compliance/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,73 @@ def sign(self, signing_key: str) -> None:
self.signature_algorithm = "hmac-sha256"
self.signature = f"hmac-sha256:{digest}"

def validate_compliance(
    self,
    extracted_tables: dict[str, list[dict[str, Any]]],
    profiles: list[str],
) -> list[ManifestWarning]:
    """Report extracted columns that a compliance profile requires but the manifest does not cover.

    A column counts as covered when this manifest lists it for its table under
    ``fields_masked``, ``fields_nulled`` or ``fields_preserved_fk``. Every other
    column is compared against each active profile's
    ``required_column_patterns``: a pattern matches when it occurs as a
    substring of the lower-cased column name (the pattern itself is used
    as-is, so patterns are presumably lower-case -- confirm in the profile
    definitions). At most one warning is produced per column per profile,
    for the first pattern that matches.

    Each warning is appended to ``self.warnings`` as well as returned.

    Args:
        extracted_tables: Extracted data, table name -> list of row dicts.
            Column names are read from the first row of each table only;
            tables with no rows are not checked.
        profiles: Names of the active compliance profiles. A name for which
            ``get_profile`` raises ``ValueError`` is skipped silently.

    Returns:
        The warnings created by this call, all with severity ``"error"``,
        in the order they were found.
    """
    from dbslice.compliance.profiles import get_profile

    # (table, column) pairs the manifest already accounts for.
    covered: set[tuple[str, str]] = set()
    for tbl, entry in self.tables.items():
        covered.update((tbl, item.column) for item in entry.fields_masked)
        covered.update((tbl, item.column) for item in entry.fields_nulled)
        covered.update((tbl, fk_col) for fk_col in entry.fields_preserved_fk)

    found: list[ManifestWarning] = []

    for name in profiles:
        try:
            profile = get_profile(name)
        except ValueError:
            # The profile lookup rejected this name; it contributes no checks.
            continue

        patterns = profile.required_column_patterns

        for tbl, rows in extracted_tables.items():
            if not rows:
                continue

            for column in list(rows[0].keys()):
                if (tbl, column) in covered:
                    continue

                lowered = column.lower()
                # First pattern contained in the column name, if any.
                hit = next((p for p in patterns if p in lowered), None)
                if hit is None:
                    continue

                warning = ManifestWarning(
                    table=tbl,
                    column=column,
                    reason=(
                        f"Column matches {profile.display_name} required "
                        f"pattern '{hit}' but was not anonymized"
                    ),
                    severity="error",
                )
                found.append(warning)
                self.warnings.append(warning)

    return found

def to_dict(self) -> dict[str, Any]:
"""Convert to a JSON-serializable dictionary."""
tables_dict: dict[str, Any] = {}
Expand Down
4 changes: 4 additions & 0 deletions src/dbslice/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from typing import Any

from dbslice.constants import (
DEFAULT_MAX_SEED_ROWS,
DEFAULT_OUTPUT_FILE_MODE,
DEFAULT_STATEMENT_TIMEOUT_MS,
DEFAULT_STREAMING_CHUNK_SIZE,
DEFAULT_STREAMING_THRESHOLD,
DEFAULT_TRAVERSAL_DEPTH,
Expand Down Expand Up @@ -316,6 +318,8 @@ class ExtractConfig:
table_direction_overrides: dict[str, TraversalDirection] = field(default_factory=dict)
row_limit_global: int | None = None
row_limit_per_table: dict[str, int] = field(default_factory=dict)
max_seed_rows: int = DEFAULT_MAX_SEED_ROWS
statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS
anonymization_seed: str | None = None
anonymization_field_providers: dict[str, str] = field(default_factory=dict)
anonymization_patterns: dict[str, str] = field(default_factory=dict)
Expand Down
25 changes: 25 additions & 0 deletions src/dbslice/config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection
from dbslice.constants import (
DEFAULT_MAX_SEED_ROWS,
DEFAULT_OUTPUT_FILE_MODE,
DEFAULT_STATEMENT_TIMEOUT_MS,
DEFAULT_STREAMING_CHUNK_SIZE,
DEFAULT_STREAMING_THRESHOLD,
DEFAULT_TRAVERSAL_DEPTH,
Expand Down Expand Up @@ -42,6 +44,8 @@
"fail_on_validation_error",
"max_rows_per_table",
"allow_unsafe_where",
"max_seed_rows",
"statement_timeout_ms",
}
_ANONYMIZATION_KEYS = {
"enabled",
Expand Down Expand Up @@ -309,6 +313,12 @@ class ExtractionConfig:
allow_unsafe_where: bool = False
"""Allow seed WHERE clauses with subqueries (trusted inputs only)."""

max_seed_rows: int | None = None
"""Maximum rows a single seed query may return (None = use global default)."""

statement_timeout_ms: int | None = None
"""PostgreSQL statement_timeout in milliseconds (None = use global default)."""


@dataclass
class AnonymizationConfig:
Expand Down Expand Up @@ -584,6 +594,8 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig":
fail_on_validation_error=extraction_data.get("fail_on_validation_error", False),
max_rows_per_table=extraction_data.get("max_rows_per_table"),
allow_unsafe_where=extraction_data.get("allow_unsafe_where", False),
max_seed_rows=extraction_data.get("max_seed_rows"),
statement_timeout_ms=extraction_data.get("statement_timeout_ms"),
)

if not isinstance(extraction.default_depth, int) or extraction.default_depth < 1:
Expand All @@ -610,6 +622,15 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig":
not isinstance(extraction.max_rows_per_table, int) or extraction.max_rows_per_table <= 0
):
raise ValueError("'extraction.max_rows_per_table' must be a positive integer")
if extraction.max_seed_rows is not None and (
not isinstance(extraction.max_seed_rows, int) or extraction.max_seed_rows <= 0
):
raise ValueError("'extraction.max_seed_rows' must be a positive integer")
if extraction.statement_timeout_ms is not None and (
not isinstance(extraction.statement_timeout_ms, int)
or extraction.statement_timeout_ms < 0
):
raise ValueError("'extraction.statement_timeout_ms' must be a non-negative integer")

anon_data = data.get("anonymization", {})
if not isinstance(anon_data, dict):
Expand Down Expand Up @@ -1195,6 +1216,8 @@ def to_extract_config(
virtual_foreign_keys=virtual_fks,
schema=final_schema,
allow_unsafe_where=final_allow_unsafe_where,
max_seed_rows=self.extraction.max_seed_rows or DEFAULT_MAX_SEED_ROWS,
statement_timeout_ms=self.extraction.statement_timeout_ms or DEFAULT_STATEMENT_TIMEOUT_MS,
compliance_profiles=self.compliance.profiles,
compliance_strict=self.compliance.strict,
generate_manifest=self.compliance.generate_manifest
Expand Down Expand Up @@ -1271,6 +1294,8 @@ def to_yaml(self, include_comments: bool = True) -> str:
output.append(f" - {table}")
if self.extraction.max_rows_per_table is not None:
output.append(f" max_rows_per_table: {self.extraction.max_rows_per_table}")
if self.extraction.max_seed_rows is not None:
output.append(f" max_seed_rows: {self.extraction.max_seed_rows}")
output.append("")

if include_comments:
Expand Down
9 changes: 9 additions & 0 deletions src/dbslice/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,12 @@

# 0o600 grants read/write to the file owner and nothing to group or others.
DEFAULT_OUTPUT_FILE_MODE = 0o600
"""Secure default permissions for newly created output files."""

# Hard cap applied per seed query.
DEFAULT_MAX_SEED_ROWS = 10_000
"""Default maximum number of rows a single seed query may return."""

# Soft limit; sits an order of magnitude below DEFAULT_MAX_SEED_ROWS.
SEED_ROW_WARNING_THRESHOLD = 1_000
"""Warn when a seed query returns more than this many rows."""

# Zero means no timeout is applied.
DEFAULT_STATEMENT_TIMEOUT_MS = 0
"""Default PostgreSQL statement_timeout in milliseconds (0 = no timeout)."""
Loading
Loading