From a2a0aec5979fabe5320696d84a41ec73eabc713a Mon Sep 17 00:00:00 2001 From: nabroleonx Date: Mon, 6 Apr 2026 17:03:51 +0300 Subject: [PATCH] fix: improve extraction engine and validation mechanisms Consolidate duplicate WHERE clause validation, fix NULL filtering inconsistency in FK traversal, add seed cardinality limits, fix Decimal precision loss in JSON output, type ExtractionResult fields, and refactor streaming deferred updates to use chunked fetching. Additionally addresses five previously unimplemented review findings: - Use \N sentinel for NULL in CSV output to distinguish from empty string - Raise error when passthrough table has no primary key (was silent skip) - Add --statement-timeout CLI flag for PostgreSQL query timeout - Validate anonymizer provider names at configure time (catch typos) - Add post-extraction compliance manifest validation BREAKING CHANGE: JSON output now serializes Decimal values as strings (e.g., "99.99") instead of floats to preserve exact precision. CSV output now uses \N for NULL values instead of empty string. Passthrough tables without a primary key now raise an error instead of being silently skipped. Seed queries that return more than --max-seed-rows rows (default: 10000) now raise an ExtractionError, and unknown anonymizer provider names now raise a ValueError at configure time. 
--- src/dbslice/adapters/postgresql.py | 20 ++- src/dbslice/cli.py | 36 +++++ src/dbslice/compliance/manifest.py | 67 ++++++++++ src/dbslice/config.py | 4 + src/dbslice/config_file.py | 25 ++++ src/dbslice/constants.py | 9 ++ src/dbslice/core/cycles.py | 87 +++++++++---- src/dbslice/core/engine.py | 88 +++++++++---- src/dbslice/core/graph.py | 8 +- src/dbslice/input_validators.py | 38 ++---- src/dbslice/output/csv_out.py | 7 +- src/dbslice/output/json_out.py | 6 +- src/dbslice/utils/anonymizer.py | 33 +++++ src/dbslice/validation.py | 2 +- tests/conftest.py | 4 +- tests/test_anonymizer.py | 39 ++++++ tests/test_compliance.py | 85 ++++++++++++ tests/test_csv_output.py | 4 +- tests/test_cycles.py | 143 ++++++++++++++++++++ tests/test_json_output.py | 13 +- tests/test_max_seed_rows.py | 203 +++++++++++++++++++++++++++++ tests/test_passthrough.py | 12 +- tests/test_performance.py | 2 +- tests/test_streaming.py | 147 +++++++++++++++++++++ tests/test_validation.py | 17 +++ tests/test_validators.py | 59 +++++++++ 26 files changed, 1056 insertions(+), 102 deletions(-) create mode 100644 tests/test_max_seed_rows.py diff --git a/src/dbslice/adapters/postgresql.py b/src/dbslice/adapters/postgresql.py index 99f9266..85d558f 100644 --- a/src/dbslice/adapters/postgresql.py +++ b/src/dbslice/adapters/postgresql.py @@ -27,6 +27,7 @@ def __init__( profiler: Any = None, schema: str | None = None, allow_unsafe_where: bool = False, + statement_timeout_ms: int = 0, ): self._conn: Any = None self._schema_name = schema or "public" @@ -34,6 +35,7 @@ def __init__( self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE self.profiler = profiler self.allow_unsafe_where = allow_unsafe_where + self.statement_timeout_ms = statement_timeout_ms def connect(self, url: str) -> None: """Establish PostgreSQL connection.""" @@ -62,6 +64,16 @@ def connect(self, url: str) -> None: # Use autocommit for reads by default self._conn.autocommit = True + # Set statement_timeout to prevent runaway queries + if 
self.statement_timeout_ms > 0: + with self._conn.cursor() as cur: + cur.execute( + "SET statement_timeout = %s", (self.statement_timeout_ms,) + ) + logger.debug( + "statement_timeout set", timeout_ms=self.statement_timeout_ms + ) + # Set search_path so unqualified table names resolve to the target schema if self._schema_name != "public": with self._conn.cursor() as cur: @@ -628,13 +640,17 @@ def fetch_referencing_pks( cur.execute(query, params) rows = cur.fetchall() for row in rows: - result.add(row) + # Filter out NULL values (nullable FKs) + if None not in row: + result.add(row) tracker.record_rows(len(rows)) else: with self._conn.cursor() as cur: cur.execute(query, params) for row in cur.fetchall(): - result.add(row) + # Filter out NULL values (nullable FKs) + if None not in row: + result.add(row) logger.debug( "Fetched referencing PKs with batching", diff --git a/src/dbslice/cli.py b/src/dbslice/cli.py index ba6fa6e..de9f4d3 100644 --- a/src/dbslice/cli.py +++ b/src/dbslice/cli.py @@ -13,7 +13,9 @@ from dbslice import __version__ from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -276,6 +278,8 @@ def _build_extract_config( output_file_mode: int, schema: str | None = None, allow_unsafe_where: bool = False, + max_seed_rows: int = DEFAULT_MAX_SEED_ROWS, + statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS, ) -> ExtractConfig: """ Build ExtractConfig from validated CLI parameters. 
@@ -301,6 +305,7 @@ def _build_extract_config( stream_threshold: Auto-enable streaming above this row count stream_chunk_size: Streaming fetch chunk size output_file_mode: Permissions mode for output files + max_seed_rows: Maximum rows a single seed query may return Returns: Configured ExtractConfig object @@ -328,6 +333,8 @@ def _build_extract_config( output_file_mode=output_file_mode, schema=schema, allow_unsafe_where=bool(allow_unsafe_where), + max_seed_rows=max_seed_rows, + statement_timeout_ms=statement_timeout_ms, ) @@ -1095,6 +1102,20 @@ def extract( help="Number of rows to fetch per chunk in streaming mode (default: 1000)", ), ] = None, + max_seed_rows: Annotated[ + int | None, + typer.Option( + "--max-seed-rows", + help=f"Maximum rows a single seed query may return (default: {DEFAULT_MAX_SEED_ROWS})", + ), + ] = None, + statement_timeout: Annotated[ + int | None, + typer.Option( + "--statement-timeout", + help="PostgreSQL statement timeout in milliseconds (0 = no timeout, default: 0)", + ), + ] = None, output_file_mode: Annotated[ str | None, typer.Option( @@ -1257,6 +1278,12 @@ def extract( effective_stream_chunk_size = ( stream_chunk_size if stream_chunk_size is not None else DEFAULT_STREAMING_CHUNK_SIZE ) + effective_max_seed_rows = ( + max_seed_rows if max_seed_rows is not None else DEFAULT_MAX_SEED_ROWS + ) + effective_statement_timeout = ( + statement_timeout if statement_timeout is not None else DEFAULT_STATEMENT_TIMEOUT_MS + ) effective_output_file_mode = DEFAULT_OUTPUT_FILE_MODE loaded_config = None @@ -1377,6 +1404,10 @@ def extract( raise ValueError("--stream-threshold must be greater than 0") if effective_stream_chunk_size <= 0: raise ValueError("--stream-chunk-size must be greater than 0") + if effective_max_seed_rows <= 0: + raise ValueError("--max-seed-rows must be greater than 0") + if effective_statement_timeout < 0: + raise ValueError("--statement-timeout must be >= 0") except ValidationError as e: console.print(f"[red]Validation 
Error:[/red] {e}") raise typer.Exit(1) @@ -1454,8 +1485,13 @@ def extract( output_file_mode=effective_output_file_mode, schema=schema, allow_unsafe_where=effective_allow_unsafe_where, + max_seed_rows=effective_max_seed_rows, ) + # Apply CLI overrides to config + extract_config.max_seed_rows = effective_max_seed_rows + extract_config.statement_timeout_ms = effective_statement_timeout + # Apply compliance settings to extract config extract_config.compliance_profiles = effective_compliance extract_config.compliance_strict = effective_compliance_strict diff --git a/src/dbslice/compliance/manifest.py b/src/dbslice/compliance/manifest.py index fb33e52..c4d467f 100644 --- a/src/dbslice/compliance/manifest.py +++ b/src/dbslice/compliance/manifest.py @@ -200,6 +200,73 @@ def sign(self, signing_key: str) -> None: self.signature_algorithm = "hmac-sha256" self.signature = f"hmac-sha256:{digest}" + def validate_compliance( + self, + extracted_tables: dict[str, list[dict[str, Any]]], + profiles: list[str], + ) -> list[ManifestWarning]: + """Validate that all columns required by compliance profiles were actually anonymized. + + Compares the columns present in the extracted data against the manifest's + recorded actions. Any column that matches a compliance profile's required + pattern but was recorded as ``unmasked`` (or not recorded at all) is + flagged as a warning. + + Args: + extracted_tables: The actual extracted data (table -> rows). + profiles: Names of active compliance profiles. + + Returns: + List of ManifestWarning objects for columns that should have been + anonymized but were not. 
+ """ + from dbslice.compliance.profiles import get_profile + + new_warnings: list[ManifestWarning] = [] + + # Build a set of (table, column) pairs that were masked or nulled + protected: set[tuple[str, str]] = set() + for table_name, table_entry in self.tables.items(): + for masked in table_entry.fields_masked: + protected.add((table_name, masked.column)) + for nulled in table_entry.fields_nulled: + protected.add((table_name, nulled.column)) + for col in table_entry.fields_preserved_fk: + protected.add((table_name, col)) + + for profile_name in profiles: + try: + profile = get_profile(profile_name) + except ValueError: + continue + + required_patterns = profile.required_column_patterns + + for table_name, rows in extracted_tables.items(): + if not rows: + continue + columns = list(rows[0].keys()) + for col in columns: + if (table_name, col) in protected: + continue + col_lower = col.lower() + for pattern in required_patterns: + if pattern in col_lower: + warning = ManifestWarning( + table=table_name, + column=col, + reason=( + f"Column matches {profile.display_name} required " + f"pattern '{pattern}' but was not anonymized" + ), + severity="error", + ) + new_warnings.append(warning) + self.warnings.append(warning) + break + + return new_warnings + def to_dict(self) -> dict[str, Any]: """Convert to a JSON-serializable dictionary.""" tables_dict: dict[str, Any] = {} diff --git a/src/dbslice/config.py b/src/dbslice/config.py index 98a2dea..bddeb2c 100644 --- a/src/dbslice/config.py +++ b/src/dbslice/config.py @@ -5,7 +5,9 @@ from typing import Any from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -316,6 +318,8 @@ class ExtractConfig: table_direction_overrides: dict[str, TraversalDirection] = field(default_factory=dict) row_limit_global: int | None = None row_limit_per_table: dict[str, int] = 
field(default_factory=dict) + max_seed_rows: int = DEFAULT_MAX_SEED_ROWS + statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS anonymization_seed: str | None = None anonymization_field_providers: dict[str, str] = field(default_factory=dict) anonymization_patterns: dict[str, str] = field(default_factory=dict) diff --git a/src/dbslice/config_file.py b/src/dbslice/config_file.py index 78a8e3c..cae7264 100644 --- a/src/dbslice/config_file.py +++ b/src/dbslice/config_file.py @@ -10,7 +10,9 @@ from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -42,6 +44,8 @@ "fail_on_validation_error", "max_rows_per_table", "allow_unsafe_where", + "max_seed_rows", + "statement_timeout_ms", } _ANONYMIZATION_KEYS = { "enabled", @@ -309,6 +313,12 @@ class ExtractionConfig: allow_unsafe_where: bool = False """Allow seed WHERE clauses with subqueries (trusted inputs only).""" + max_seed_rows: int | None = None + """Maximum rows a single seed query may return (None = use global default).""" + + statement_timeout_ms: int | None = None + """PostgreSQL statement_timeout in milliseconds (None = use global default).""" + @dataclass class AnonymizationConfig: @@ -584,6 +594,8 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": fail_on_validation_error=extraction_data.get("fail_on_validation_error", False), max_rows_per_table=extraction_data.get("max_rows_per_table"), allow_unsafe_where=extraction_data.get("allow_unsafe_where", False), + max_seed_rows=extraction_data.get("max_seed_rows"), + statement_timeout_ms=extraction_data.get("statement_timeout_ms"), ) if not isinstance(extraction.default_depth, int) or extraction.default_depth < 1: @@ -610,6 +622,15 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": not 
isinstance(extraction.max_rows_per_table, int) or extraction.max_rows_per_table <= 0 ): raise ValueError("'extraction.max_rows_per_table' must be a positive integer") + if extraction.max_seed_rows is not None and ( + not isinstance(extraction.max_seed_rows, int) or extraction.max_seed_rows <= 0 + ): + raise ValueError("'extraction.max_seed_rows' must be a positive integer") + if extraction.statement_timeout_ms is not None and ( + not isinstance(extraction.statement_timeout_ms, int) + or extraction.statement_timeout_ms < 0 + ): + raise ValueError("'extraction.statement_timeout_ms' must be a non-negative integer") anon_data = data.get("anonymization", {}) if not isinstance(anon_data, dict): @@ -1195,6 +1216,8 @@ def to_extract_config( virtual_foreign_keys=virtual_fks, schema=final_schema, allow_unsafe_where=final_allow_unsafe_where, + max_seed_rows=self.extraction.max_seed_rows or DEFAULT_MAX_SEED_ROWS, + statement_timeout_ms=self.extraction.statement_timeout_ms or DEFAULT_STATEMENT_TIMEOUT_MS, compliance_profiles=self.compliance.profiles, compliance_strict=self.compliance.strict, generate_manifest=self.compliance.generate_manifest @@ -1271,6 +1294,8 @@ def to_yaml(self, include_comments: bool = True) -> str: output.append(f" - {table}") if self.extraction.max_rows_per_table is not None: output.append(f" max_rows_per_table: {self.extraction.max_rows_per_table}") + if self.extraction.max_seed_rows is not None: + output.append(f" max_seed_rows: {self.extraction.max_seed_rows}") output.append("") if include_comments: diff --git a/src/dbslice/constants.py b/src/dbslice/constants.py index ee55822..bc947de 100644 --- a/src/dbslice/constants.py +++ b/src/dbslice/constants.py @@ -30,3 +30,12 @@ DEFAULT_OUTPUT_FILE_MODE = 0o600 """Secure default permissions for newly created output files.""" + +DEFAULT_MAX_SEED_ROWS = 10000 +"""Default maximum number of rows a single seed query may return.""" + +SEED_ROW_WARNING_THRESHOLD = 1000 +"""Warn when a seed query returns more than 
this many rows.""" + +DEFAULT_STATEMENT_TIMEOUT_MS = 0 +"""Default PostgreSQL statement_timeout in milliseconds (0 = no timeout).""" diff --git a/src/dbslice/core/cycles.py b/src/dbslice/core/cycles.py index 438a656..61aa78d 100644 --- a/src/dbslice/core/cycles.py +++ b/src/dbslice/core/cycles.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from dataclasses import dataclass from typing import Any @@ -252,42 +253,74 @@ def build_deferred_updates( Returns: List of DeferredUpdate objects describing the UPDATE statements needed """ - deferred_updates = [] + # Wrap each table's row list as a single-chunk iterator to reuse the + # chunked implementation, avoiding duplicated logic. + chunk_iterators: dict[str, Iterator[list[dict[str, Any]]]] = { + table: iter([rows]) for table, rows in tables_data.items() + } + return build_deferred_updates_chunked(fks_to_break, schema, chunk_iterators) + +def build_deferred_updates_chunked( + fks_to_break: list[ForeignKey], + schema: SchemaGraph, + chunk_iterators: dict[str, Iterator[list[dict[str, Any]]]], +) -> list[DeferredUpdate]: + """Build deferred UPDATE statements by processing rows in chunks. + + Unlike `build_deferred_updates` which requires all row data in memory, + this variant processes rows incrementally via chunk iterators. Only the + resulting DeferredUpdate objects (which contain just PK and FK values) + are accumulated in memory. + + Args: + fks_to_break: List of ForeignKey objects that were broken + schema: Database schema graph + chunk_iterators: Map of table name to an iterator yielding chunks + (lists) of row dicts + + Returns: + List of DeferredUpdate objects describing the UPDATE statements needed + """ + # Group broken FKs by source table so we can process each table's + # chunks once and check all relevant FKs per row. 
+ fks_by_table: dict[str, list[ForeignKey]] = {} for fk in fks_to_break: - source_table = fk.source_table - fk_columns = fk.source_columns + fks_by_table.setdefault(fk.source_table, []).append(fk) - if source_table not in tables_data: + deferred_updates: list[DeferredUpdate] = [] + + for table, table_fks in fks_by_table.items(): + if table not in chunk_iterators: continue - table_info = schema.get_table(source_table) + table_info = schema.get_table(table) if not table_info: continue pk_columns = table_info.primary_key - for row_data in tables_data[source_table]: - for fk_col in fk_columns: - if fk_col not in row_data: - continue - - fk_value = row_data[fk_col] - - # Skip if value is already NULL (no update needed) - if fk_value is None: - continue - - pk_values = tuple(row_data[col] for col in pk_columns) - - deferred_updates.append( - DeferredUpdate( - table=source_table, - pk_columns=pk_columns, - pk_values=pk_values, - fk_column=fk_col, - fk_value=fk_value, - ) - ) + for chunk in chunk_iterators[table]: + for row_data in chunk: + for fk in table_fks: + for fk_col in fk.source_columns: + if fk_col not in row_data: + continue + + fk_value = row_data[fk_col] + if fk_value is None: + continue + + pk_values = tuple(row_data[col] for col in pk_columns) + + deferred_updates.append( + DeferredUpdate( + table=table, + pk_columns=pk_columns, + pk_values=pk_values, + fk_column=fk_col, + fk_value=fk_value, + ) + ) return deferred_updates diff --git a/src/dbslice/core/engine.py b/src/dbslice/core/engine.py index e9f9266..efe785a 100644 --- a/src/dbslice/core/engine.py +++ b/src/dbslice/core/engine.py @@ -12,7 +12,12 @@ SeedSpec, TraversalDirection, ) -from dbslice.constants import DEFAULT_ANONYMIZATION_SEED, DEFAULT_TRAVERSAL_DEPTH +from dbslice.constants import ( + DEFAULT_ANONYMIZATION_SEED, + DEFAULT_TRAVERSAL_DEPTH, + SEED_ROW_WARNING_THRESHOLD, +) +from dbslice.core.cycles import CycleInfo, DeferredUpdate from dbslice.core.graph import GraphTraverser, TraversalConfig, 
TraversalResult from dbslice.exceptions import ( ExtractionError, @@ -24,6 +29,7 @@ from dbslice.output.sql import SQLGenerator from dbslice.utils.anonymizer import DeterministicAnonymizer from dbslice.utils.connection import get_adapter_for_url, parse_database_url +from dbslice.utils.profiling import QueryProfiler from dbslice.validation import ExtractionValidator, ValidationResult logger = get_logger(__name__) @@ -72,11 +78,11 @@ class ExtractionResult: stats: dict[str, int] = field(default_factory=dict) traversal_path: list[str] = field(default_factory=list) has_cycles: bool = False - broken_fks: list[Any] = field(default_factory=list) # list[ForeignKey] - deferred_updates: list[Any] = field(default_factory=list) # list[DeferredUpdate] - cycle_infos: list[Any] = field(default_factory=list) # list[CycleInfo] + broken_fks: list[ForeignKey] = field(default_factory=list) + deferred_updates: list[DeferredUpdate] = field(default_factory=list) + cycle_infos: list[CycleInfo] = field(default_factory=list) validation_result: ValidationResult | None = None - profiler: Any = None # Optional QueryProfiler + profiler: QueryProfiler | None = None used_deferred_cycle_strategy: bool = False def total_rows(self) -> int: @@ -206,6 +212,7 @@ def extract(self) -> tuple[ExtractionResult, SchemaGraph]: profiler=profiler, schema=self.config.schema, allow_unsafe_where=self.config.allow_unsafe_where, + statement_timeout_ms=self.config.statement_timeout_ms, ) else: self.adapter = get_adapter_for_url(self.config.database_url) @@ -536,6 +543,17 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: for table, rows in tables_data.items(): self.manifest.set_table_row_count(table, len(rows)) + if self.config.compliance_profiles: + compliance_warnings = self.manifest.validate_compliance( + tables_data, self.config.compliance_profiles + ) + if compliance_warnings: + logger.warning( + "Compliance validation found unanonymized columns", + warning_count=len(compliance_warnings), + 
profiles=self.config.compliance_profiles, + ) + return ExtractionResult( tables=tables_data, insert_order=insert_order, @@ -994,6 +1012,22 @@ def _process_seed(self, seed: SeedSpec) -> TraversalResult: ) seed_rows = list(self.adapter.fetch_rows(seed.table, where_clause, params)) + if len(seed_rows) > self.config.max_seed_rows: + raise ExtractionError( + f"Seed query returned {len(seed_rows)} rows, exceeding limit of " + f"{self.config.max_seed_rows}. Use --max-seed-rows to increase the " + f"limit or refine your WHERE clause.", + table=seed.table, + ) + + if len(seed_rows) > SEED_ROW_WARNING_THRESHOLD: + logger.warning( + "Seed query returned a large number of rows", + table=seed.table, + row_count=len(seed_rows), + limit=self.config.max_seed_rows, + ) + if not seed_rows: seed_str = ( f"{seed.table}.{seed.column}={seed.value}" @@ -1060,8 +1094,8 @@ def _do_streaming_extract( db_type: DatabaseType, all_records: dict[str, set[tuple[Any, ...]]], insert_order: list[str], - broken_fks: list[Any], - cycle_infos: list[Any], + broken_fks: list[ForeignKey], + cycle_infos: list[CycleInfo], all_paths: list[str], used_deferred_cycle_strategy: bool, ) -> ExtractionResult: @@ -1091,27 +1125,36 @@ def _do_streaming_extract( from dbslice.core.streaming import StreamingExtractionEngine - deferred_updates = [] + deferred_updates: list[Any] = [] if broken_fks: - from dbslice.core.cycles import build_deferred_updates + from dbslice.core.cycles import build_deferred_updates_chunked self._log("cycles", f"Breaking {len(broken_fks)} circular reference(s)...") - # For streaming mode, we need to build deferred updates without having - # all data in memory. We fetch the necessary data on-demand. 
with logger.timed_operation("build_deferred_updates_streaming"): - temp_data = {} + tables_needed: set[str] = set() for fk in broken_fks: - table = fk.source_table - if table not in temp_data and table in all_records: - pk_values = all_records[table] - table_info = self.schema.get_table(table) - if table_info: - pk_columns = table_info.primary_key - rows = list(self.adapter.fetch_by_pk(table, pk_columns, pk_values)) - temp_data[table] = rows + if fk.source_table in all_records: + tables_needed.add(fk.source_table) + + chunk_iterators = {} + for table in tables_needed: + pk_values = all_records[table] + table_info = self.schema.get_table(table) + if table_info: + pk_columns = table_info.primary_key + chunk_iterators[table] = ( + self.adapter.fetch_by_pk_chunked( + table, + pk_columns, + pk_values, + chunk_size=self.config.streaming_chunk_size, + ) + ) - deferred_updates = build_deferred_updates(broken_fks, temp_data, self.schema) + deferred_updates = build_deferred_updates_chunked( + broken_fks, self.schema, chunk_iterators + ) logger.info( "Circular references resolved", @@ -1145,7 +1188,7 @@ def _do_streaming_extract( def _topological_sort( self, tables: set[str], db_type: DatabaseType - ) -> tuple[list[str], list[Any], list[Any], bool]: + ) -> tuple[list[str], list[ForeignKey], list[CycleInfo], bool]: """ Topologically sort tables based on FK dependencies with cycle handling. 
@@ -1158,7 +1201,6 @@ def _topological_sort( assert self.schema is not None from dbslice.core.cycles import ( - CycleInfo, break_cycles_at_nullable_fks, find_cycles_dfs, identify_cycle_fks, diff --git a/src/dbslice/core/graph.py b/src/dbslice/core/graph.py index 587cc3b..024aee8 100644 --- a/src/dbslice/core/graph.py +++ b/src/dbslice/core/graph.py @@ -5,6 +5,7 @@ from dbslice.adapters.base import DatabaseAdapter from dbslice.config import TraversalDirection from dbslice.constants import DEFAULT_TRAVERSAL_DEPTH +from dbslice.exceptions import ExtractionError from dbslice.logging import get_logger from dbslice.models import ForeignKey, SchemaGraph @@ -372,11 +373,12 @@ def _process_passthrough_tables( pk_columns = table_info.primary_key if not pk_columns: - logger.warning( - "Passthrough table has no primary key, skipping", + raise ExtractionError( + f"Passthrough table '{table}' has no primary key. " + f"Passthrough tables must have a primary key to extract all rows. " + f"Remove it from passthrough or add a primary key.", table=table, ) - continue logger.debug("Fetching all rows from passthrough table", table=table) all_pks = self.adapter.fetch_all_pks(table, pk_columns) diff --git a/src/dbslice/input_validators.py b/src/dbslice/input_validators.py index e870e7a..80052b9 100644 --- a/src/dbslice/input_validators.py +++ b/src/dbslice/input_validators.py @@ -187,10 +187,12 @@ def validate_column_name(column: str) -> None: def validate_where_clause(where_clause: str) -> None: """ - Validate a WHERE clause for basic safety. + Validate a WHERE clause for safety against SQL injection. - Note: This is basic validation. SQL injection protection should also - be handled by using parameterized queries in the database layer. + Delegates to the comprehensive implementation in ``config.validate_where_clause`` + which handles Unicode normalization, quote stripping, dollar-quoting, type casts, + and dangerous PostgreSQL functions. 
This wrapper adds a length check and converts + the exception type so callers that expect ``SeedValidationError`` are unaffected. Args: where_clause: The WHERE clause to validate @@ -211,30 +213,16 @@ def validate_where_clause(where_clause: str) -> None: f"WHERE clause too long (max {MAX_WHERE_CLAUSE_LENGTH} characters)", ) - if ";" in where_clause: + from dbslice.config import validate_where_clause as _config_validate + from dbslice.exceptions import InsecureWhereClauseError + + try: + _config_validate(where_clause) + except InsecureWhereClauseError as exc: raise SeedValidationError( where_clause, - "WHERE clause contains potentially dangerous SQL patterns (semicolon found)", - ) - - dangerous_patterns = [ - (r"\bdrop\s+table\b", "DROP TABLE"), - (r"\bdelete\s+from\b", "DELETE FROM"), - (r"\btruncate\b", "TRUNCATE"), - (r"\balter\s+table\b", "ALTER TABLE"), - (r"\bunion\s+select\b", "UNION SELECT"), - (r"\bexec\s*\(", "EXEC"), - (r"\bexecute\s*\(", "EXECUTE"), - (r"--", "SQL comment"), - (r"/\*", "SQL comment"), - ] - - where_lower = where_clause.lower() - for pattern, name in dangerous_patterns: - if re.search(pattern, where_lower, re.IGNORECASE): - raise SeedValidationError( - where_clause, f"WHERE clause contains potentially dangerous SQL patterns ({name})" - ) + f"WHERE clause contains potentially dangerous SQL patterns ({exc.dangerous_keyword})", + ) from exc def validate_seed_value(value: Any) -> None: diff --git a/src/dbslice/output/csv_out.py b/src/dbslice/output/csv_out.py index 4d30dd7..a5d8063 100644 --- a/src/dbslice/output/csv_out.py +++ b/src/dbslice/output/csv_out.py @@ -212,7 +212,7 @@ def _format_value(self, value: Any) -> str: Format a Python value as CSV field value. 
Type conversions: - - None -> empty string (CSV convention for NULL) + - None -> ``\\N`` (PostgreSQL COPY convention for NULL) - bool -> "true"/"false" - datetime -> ISO 8601 string - date -> ISO 8601 date string @@ -231,8 +231,9 @@ def _format_value(self, value: Any) -> str: String representation suitable for CSV """ if value is None: - # CSV convention: NULL is represented as empty field - return "" + # Use \N sentinel (PostgreSQL COPY convention) to distinguish + # NULL from empty string, enabling lossless round-trips. + return "\\N" if isinstance(value, bool): # Use lowercase for consistency with JSON diff --git a/src/dbslice/output/json_out.py b/src/dbslice/output/json_out.py index bdf6be3..28b6584 100644 --- a/src/dbslice/output/json_out.py +++ b/src/dbslice/output/json_out.py @@ -19,7 +19,7 @@ class DatabaseTypeEncoder(json.JSONEncoder): - date -> ISO 8601 date string (YYYY-MM-DD) - time -> ISO 8601 time string (HH:MM:SS[.ffffff]) - timedelta -> total seconds (as float) - - Decimal -> float + - Decimal -> string (preserves exact precision) - UUID -> string - bytes -> hex string - Any other non-serializable type -> string representation @@ -49,8 +49,8 @@ def default(self, obj: Any) -> Any: return obj.total_seconds() if isinstance(obj, Decimal): - # Convert to float for JSON compatibility - return float(obj) + # Convert to string to preserve exact precision for financial data + return str(obj) if isinstance(obj, UUID): return str(obj) diff --git a/src/dbslice/utils/anonymizer.py b/src/dbslice/utils/anonymizer.py index afbfddd..f47f6c9 100644 --- a/src/dbslice/utils/anonymizer.py +++ b/src/dbslice/utils/anonymizer.py @@ -266,6 +266,9 @@ def configure( ] self.security_null_fields = [pattern.lower() for pattern in (security_null_fields or [])] + # Validate that all explicit provider names are resolvable + self._validate_provider_names() + logger.info( "Anonymizer configured", redact_field_count=len(self.redact_fields), @@ -275,6 +278,36 @@ def configure( 
security_null_pattern_count=len(self.security_null_fields), ) + def _validate_provider_names(self) -> None: + """Validate that all configured provider names are valid Faker methods or custom transformers.""" + from dbslice.compliance.transformers import CUSTOM_TRANSFORMERS + + # Collect all explicitly configured provider names + all_providers: dict[str, str] = {} # provider -> source description + for field, provider in self.field_providers.items(): + all_providers.setdefault(provider, f"field_providers[{field}]") + for pattern, provider in self.custom_patterns: + all_providers.setdefault(provider, f"patterns[{pattern}]") + for pattern, provider in self.fallback_patterns: + all_providers.setdefault(provider, f"fallback_patterns[{pattern}]") + + invalid = [] + for provider, source in all_providers.items(): + if provider in CUSTOM_TRANSFORMERS: + continue + if hasattr(self.fake, provider): + continue + invalid.append((provider, source)) + + if invalid: + details = ", ".join(f"'{p}' (from {s})" for p, s in invalid) + raise ValueError( + f"Unknown anonymization provider(s): {details}. " + f"Each provider must be a valid Faker method or a custom transformer " + f"({', '.join(sorted(CUSTOM_TRANSFORMERS))}). " + f"Check for typos in your anonymization configuration." + ) + def _is_foreign_key_column(self, table: str, column: str) -> bool: """ Check if a column is part of a foreign key. 
diff --git a/src/dbslice/validation.py b/src/dbslice/validation.py index 240fafa..b5ac218 100644 --- a/src/dbslice/validation.py +++ b/src/dbslice/validation.py @@ -295,7 +295,7 @@ def _extract_fk_values( Returns: Tuple of FK values """ - return tuple(row.get(col) for col in fk_columns) + return tuple(row[col] for col in fk_columns) def _has_parent_record( self, diff --git a/tests/conftest.py b/tests/conftest.py index 93d94c4..f1c2cca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -305,7 +305,9 @@ def fetch_referencing_pks( fk_tuple = tuple(row[col] for col in fk_cols) if fk_tuple in target_pk_values: pk_tuple = tuple(row[col] for col in pk_cols) - result.add(pk_tuple) + # Filter out NULL values (nullable FKs) + if None not in pk_tuple: + result.add(pk_tuple) return result diff --git a/tests/test_anonymizer.py b/tests/test_anonymizer.py index dc86580..fec03e1 100644 --- a/tests/test_anonymizer.py +++ b/tests/test_anonymizer.py @@ -1,5 +1,7 @@ """Tests for anonymization functionality.""" +import pytest + from dbslice.models import Column, ForeignKey, SchemaGraph, Table from dbslice.utils.anonymizer import DeterministicAnonymizer @@ -548,3 +550,40 @@ def test_column_name_in_hash(self): assert result_firstname == result_firstname2 assert result_lastname == result_lastname2 + + def test_invalid_field_provider_raises(self): + """Typo in field_providers should raise ValueError at configure time.""" + anon = DeterministicAnonymizer() + with pytest.raises(ValueError, match="Unknown anonymization provider"): + anon.configure( + [], + field_providers={"users.email": "nonexistent_faker_method"}, + ) + + def test_invalid_pattern_provider_raises(self): + """Typo in pattern provider should raise ValueError at configure time.""" + anon = DeterministicAnonymizer() + with pytest.raises(ValueError, match="Unknown anonymization provider"): + anon.configure( + [], + patterns={"*.email": "hipaa_zip_3"}, # typo: should be hipaa_zip3 + ) + + def 
test_custom_transformer_names_are_valid(self): + """Custom transformer names (e.g., hipaa_zip3) should pass validation.""" + anon = DeterministicAnonymizer() + # Should not raise + anon.configure( + [], + patterns={"*.zip": "hipaa_zip3", "*.dob": "year_only"}, + ) + + def test_valid_faker_methods_pass_validation(self): + """Standard Faker method names should pass validation.""" + anon = DeterministicAnonymizer() + # Should not raise + anon.configure( + [], + field_providers={"users.email": "email", "users.name": "name"}, + patterns={"*.phone": "phone_number"}, + ) diff --git a/tests/test_compliance.py b/tests/test_compliance.py index c50b52b..7350913 100644 --- a/tests/test_compliance.py +++ b/tests/test_compliance.py @@ -699,3 +699,88 @@ def fetch_rows(self, table: str, where_clause: str, params: tuple[object, ...]): target_table=None, console=MagicMock(), ) + + +class TestManifestComplianceValidation: + """Tests for post-extraction compliance manifest validation.""" + + def test_detects_unanonymized_column(self): + """Columns matching compliance patterns that were not masked should be flagged.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + # Record that users.id was unmasked (fine) and users.email was unmasked (bad) + manifest.record_unmasked_field("users", "id") + manifest.record_unmasked_field("users", "email") + + extracted = { + "users": [{"id": 1, "email": "alice@example.com"}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + assert len(warnings) == 1 + assert warnings[0].table == "users" + assert warnings[0].column == "email" + assert warnings[0].severity == "error" + assert "GDPR" in warnings[0].reason + + def test_no_warning_for_masked_column(self): + """Masked columns should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + manifest.record_masked_field("users", "email", "email") + + extracted = { 
+ "users": [{"id": 1, "email": "fake@example.com"}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + email_warnings = [w for w in warnings if w.column == "email"] + assert len(email_warnings) == 0 + + def test_no_warning_for_nulled_column(self): + """NULLed columns should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["hipaa"]) + + manifest.record_nulled_field("users", "password", "security_null_pattern") + + extracted = { + "users": [{"id": 1, "password": None}], + } + + warnings = manifest.validate_compliance(extracted, ["hipaa"]) + + password_warnings = [w for w in warnings if w.column == "password"] + assert len(password_warnings) == 0 + + def test_fk_preserved_columns_not_flagged(self): + """FK-preserved columns should not produce warnings even if they match patterns.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + manifest.record_fk_preserved("orders", "user_name_id") + + extracted = { + "orders": [{"id": 1, "user_name_id": 42}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + fk_warnings = [w for w in warnings if w.column == "user_name_id"] + assert len(fk_warnings) == 0 + + def test_empty_tables_no_warnings(self): + """Empty tables should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + extracted: dict[str, list[dict]] = { + "users": [], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + assert len(warnings) == 0 diff --git a/tests/test_csv_output.py b/tests/test_csv_output.py index 37a4344..0d89462 100644 --- a/tests/test_csv_output.py +++ b/tests/test_csv_output.py @@ -133,7 +133,7 @@ def test_generate_per_table_separate_headers(self, per_table_generator, sample_t assert "id,user_id,total" in orders_lines[0] def test_format_value_none(self, generator): - assert 
generator._format_value(None) == "" + assert generator._format_value(None) == "\\N" def test_format_value_bool(self, generator): assert generator._format_value(True) == "true" @@ -307,7 +307,7 @@ def test_null_value_in_csv(self, generator, sample_tables_schema): reader = csv.DictReader(csv_output.strip().split("\n")) rows = list(reader) - assert rows[0]["name"] == "" # NULL becomes empty string in CSV + assert rows[0]["name"] == "\\N" # NULL uses \N sentinel (PostgreSQL COPY convention) def test_write_to_file_single_mode(self, generator, sample_tables_schema, tmp_path): tables_data = { diff --git a/tests/test_cycles.py b/tests/test_cycles.py index 71a0f99..25889fe 100644 --- a/tests/test_cycles.py +++ b/tests/test_cycles.py @@ -7,6 +7,7 @@ DeferredUpdate, break_cycles_at_nullable_fks, build_deferred_updates, + build_deferred_updates_chunked, find_cycles_dfs, identify_cycle_fks, select_nullable_fk_to_break, @@ -427,6 +428,148 @@ def test_format_where_clause_composite(self): assert " AND " in where_clause +class TestBuildDeferredUpdatesChunked: + """Tests for the chunked (streaming-friendly) deferred update builder.""" + + def test_matches_non_chunked_output(self): + """Chunked variant should produce same DeferredUpdates as the original.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + rows = [ + {"id": 1, "name": "Alice", "manager_id": 2}, + {"id": 2, "name": "Bob", "manager_id": 1}, + {"id": 3, "name": "Charlie", "manager_id": None}, + ] + + # Non-chunked + expected = build_deferred_updates([fk], {"employees": rows}, schema) + + # Chunked — split rows into multiple small chunks + def chunk_iter(): + yield [rows[0]] + yield 
[rows[1], rows[2]] + + chunked_result = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert len(chunked_result) == len(expected) + for a, b in zip(expected, chunked_result): + assert a.table == b.table + assert a.pk_columns == b.pk_columns + assert a.pk_values == b.pk_values + assert a.fk_column == b.fk_column + assert a.fk_value == b.fk_value + + def test_skips_null_fk_values(self): + """Chunked variant should skip NULL FK values like the original.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + def chunk_iter(): + yield [ + {"id": 1, "name": "Alice", "manager_id": None}, + {"id": 2, "name": "Bob", "manager_id": 1}, + ] + + updates = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert len(updates) == 1 + assert updates[0].pk_values == (2,) + assert updates[0].fk_value == 1 + + def test_empty_iterator(self): + """Chunked variant handles empty iterator.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + def chunk_iter(): + return iter([]) + + updates = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert updates == [] + + def test_missing_table_in_iterators(self): + """Chunked variant handles missing table gracefully.""" + fk = ForeignKey( + name="fk_manager", + 
source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + # No iterator for the table + updates = build_deferred_updates_chunked([fk], schema, {}) + + assert updates == [] + + class TestCycleInfo: """Tests for CycleInfo dataclass.""" diff --git a/tests/test_json_output.py b/tests/test_json_output.py index d322790..ef130ec 100644 --- a/tests/test_json_output.py +++ b/tests/test_json_output.py @@ -43,6 +43,13 @@ def test_encode_decimal(self): result = json.dumps(d, cls=DatabaseTypeEncoder) assert "99.99" in result + def test_encode_decimal_large_value_preserves_precision(self): + """Verify that large Decimal values preserve exact precision (no float rounding).""" + d = Decimal("99999999999999.99") + result = json.dumps(d, cls=DatabaseTypeEncoder) + parsed = json.loads(result) + assert parsed == "99999999999999.99" + def test_encode_uuid(self): u = UUID("12345678-1234-5678-1234-567812345678") result = json.dumps(u, cls=DatabaseTypeEncoder) @@ -83,7 +90,7 @@ def test_encode_mixed_types(self): # Verify all types were encoded correctly assert "2024-01-15" in parsed["datetime"] assert parsed["date"] == "2024-01-15" - assert parsed["decimal"] == 99.99 + assert parsed["decimal"] == "99.99" assert "12345678-1234-5678" in parsed["uuid"] assert parsed["bytes"] == "0001" assert parsed["string"] == "test" @@ -272,7 +279,7 @@ def test_generate_with_special_types(self, sample_tables_schema): user = parsed["tables"]["users"][0] assert "2024-01-15" in user["created_at"] assert user["birthday"] == "1990-05-20" - assert user["balance"] == 1234.56 + assert user["balance"] == "1234.56" assert "12345678-1234-5678" in user["uuid"] assert user["avatar"] == "89504e47" # hex of b"\x89PNG" @@ -554,4 +561,4 @@ 
def test_full_extraction_workflow(self, tmp_path): assert len(parsed["tables"]["users"]) == 2 assert len(parsed["tables"]["orders"]) == 3 assert parsed["tables"]["users"][0]["email"] == "user1@example.com" - assert parsed["tables"]["orders"][0]["amount"] == 99.99 + assert parsed["tables"]["orders"][0]["amount"] == "99.99" diff --git a/tests/test_max_seed_rows.py b/tests/test_max_seed_rows.py new file mode 100644 index 0000000..b94d65b --- /dev/null +++ b/tests/test_max_seed_rows.py @@ -0,0 +1,203 @@ +"""Tests for max_seed_rows safety limit on seed queries.""" + +from unittest.mock import patch + +import pytest + +from dbslice.config import ExtractConfig, SeedSpec +from dbslice.constants import DEFAULT_MAX_SEED_ROWS, SEED_ROW_WARNING_THRESHOLD +from dbslice.core.engine import ExtractionEngine +from dbslice.exceptions import ExtractionError +from dbslice.models import Column, SchemaGraph, Table +from tests.conftest import MockAdapter + + +def _make_schema_and_adapter( + row_count: int, +) -> tuple[SchemaGraph, MockAdapter]: + """Create a simple schema with a users table and the given number of rows.""" + users_table = Table( + name="users", + schema="public", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column(name="status", data_type="TEXT", nullable=True, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + schema = SchemaGraph(tables={"users": users_table}, edges=[]) + rows = [{"id": i, "status": "active"} for i in range(1, row_count + 1)] + adapter = MockAdapter(schema, {"users": rows}) + return schema, adapter + + +class TestMaxSeedRowsDefault: + """The default max_seed_rows value works correctly.""" + + def test_default_value_is_set(self): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + ) + assert config.max_seed_rows == DEFAULT_MAX_SEED_ROWS + + def test_seed_within_default_limit_succeeds(self): + schema, adapter = 
_make_schema_and_adapter(5) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # Should not raise + engine._process_seed(config.seeds[0]) + + +class TestMaxSeedRowsExceedsLimit: + """Seed queries exceeding the limit raise ExtractionError.""" + + def test_exceeds_default_limit(self): + row_count = DEFAULT_MAX_SEED_ROWS + 1 + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="exceeding limit"): + engine._process_seed(config.seeds[0]) + + def test_exceeds_custom_limit(self): + schema, adapter = _make_schema_and_adapter(6) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=5, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="exceeding limit of 5"): + engine._process_seed(config.seeds[0]) + + def test_exactly_at_limit_succeeds(self): + schema, adapter = _make_schema_and_adapter(5) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=5, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # Should not raise -- exactly at limit, not over + engine._process_seed(config.seeds[0]) + + def test_error_message_includes_row_count_and_limit(self): + schema, adapter = _make_schema_and_adapter(20) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + 
max_seed_rows=10, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match=r"20 rows.*limit of 10"): + engine._process_seed(config.seeds[0]) + + def test_error_message_mentions_cli_flag(self): + schema, adapter = _make_schema_and_adapter(20) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=10, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="--max-seed-rows"): + engine._process_seed(config.seeds[0]) + + +class TestMaxSeedRowsCustomConfig: + """Custom max_seed_rows via config works correctly.""" + + def test_custom_limit_via_config(self): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + max_seed_rows=500, + ) + assert config.max_seed_rows == 500 + + def test_high_limit_allows_large_result(self): + schema, adapter = _make_schema_and_adapter(50) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=100, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # 50 rows is under the 100 limit, should succeed + engine._process_seed(config.seeds[0]) + + +class TestSeedRowWarningThreshold: + """Warning is emitted when seed rows exceed the warning threshold.""" + + def test_warning_logged_above_threshold(self): + row_count = SEED_ROW_WARNING_THRESHOLD + 1 + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=row_count + 1, # Don't hit the hard limit + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with 
patch("dbslice.core.engine.logger") as mock_logger: + engine._process_seed(config.seeds[0]) + mock_logger.warning.assert_called_once() + call_args = mock_logger.warning.call_args + assert "large number of rows" in call_args[0][0] + + def test_no_warning_below_threshold(self): + row_count = SEED_ROW_WARNING_THRESHOLD + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=row_count + 1, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with patch("dbslice.core.engine.logger") as mock_logger: + engine._process_seed(config.seeds[0]) + mock_logger.warning.assert_not_called() diff --git a/tests/test_passthrough.py b/tests/test_passthrough.py index ace170a..514a156 100644 --- a/tests/test_passthrough.py +++ b/tests/test_passthrough.py @@ -3,6 +3,7 @@ import pytest from dbslice.core.graph import GraphTraverser, TraversalConfig +from dbslice.exceptions import ExtractionError from dbslice.models import Column, ForeignKey, SchemaGraph, Table from tests.conftest import MockAdapter @@ -383,7 +384,7 @@ def test_passthrough_with_exclude_tables( def test_passthrough_table_without_pk(passthrough_schema: SchemaGraph): - """Test that passthrough tables without primary keys are skipped.""" + """Test that passthrough tables without primary keys raise an error.""" # Add a table without a primary key no_pk_table = Table( name="logs", @@ -418,10 +419,5 @@ def test_passthrough_table_without_pk(passthrough_schema: SchemaGraph): passthrough_tables={"logs", "countries"}, ) - result = traverser.traverse("users", seed_pks, config) - - # Table without PK should not be included - assert "logs" not in result.records - - # Table with PK should be included - assert "countries" in result.records + with pytest.raises(ExtractionError, match="no primary key"): + traverser.traverse("users", seed_pks, config) diff 
--git a/tests/test_performance.py b/tests/test_performance.py index f235b39..97e8b51 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -225,7 +225,7 @@ def test_engine_passes_configured_batch_size_to_adapter(monkeypatch): captured: dict[str, int | None] = {} class FakePostgreSQLAdapter: - def __init__(self, batch_size=None, profiler=None, schema=None, allow_unsafe_where=False): + def __init__(self, batch_size=None, profiler=None, schema=None, allow_unsafe_where=False, statement_timeout_ms=0): captured["batch_size"] = batch_size def connect(self, url: str) -> None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 503e6ff..6a36249 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -509,6 +509,153 @@ def test_streaming_with_cycles(sample_schema, tmp_path): assert "orders" in sql_content +def test_streaming_deferred_updates_use_chunked_fetch(sample_schema, tmp_path): + """Streaming mode should build deferred updates via chunked fetching, not list(). + + This verifies that the streaming extract path uses fetch_by_pk_chunked + (not fetch_by_pk with list()) when processing broken FK tables for + deferred updates, so that memory stays bounded. 
+ """ + from dbslice.models import Column, ForeignKey, Table + + # Build a schema with a circular dependency: users <-> orders + users_table = Table( + name="users", + schema="main", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column(name="email", data_type="TEXT", nullable=False, is_primary_key=False), + Column(name="name", data_type="TEXT", nullable=True, is_primary_key=False), + Column( + name="last_order_id", + data_type="INTEGER", + nullable=True, + is_primary_key=False, + ), + ], + primary_key=("id",), + foreign_keys=[], + ) + + orders_table = Table( + name="orders", + schema="main", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column( + name="user_id", + data_type="INTEGER", + nullable=False, + is_primary_key=False, + ), + Column(name="total", data_type="REAL", nullable=True, is_primary_key=False), + Column(name="status", data_type="TEXT", nullable=True, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + + fk_orders_users = ForeignKey( + name="fk_orders_users", + source_table="orders", + source_columns=("user_id",), + target_table="users", + target_columns=("id",), + is_nullable=False, + ) + fk_users_orders = ForeignKey( + name="fk_users_orders", + source_table="users", + source_columns=("last_order_id",), + target_table="orders", + target_columns=("id",), + is_nullable=True, + ) + + from dbslice.models import SchemaGraph + + schema = SchemaGraph( + tables={"users": users_table, "orders": orders_table}, + edges=[fk_orders_users, fk_users_orders], + ) + + data = { + "users": [ + {"id": i, "email": f"u{i}@test.com", "name": f"U{i}", "last_order_id": i} + for i in range(1, 6) + ], + "orders": [ + {"id": i, "user_id": i, "total": 10.0 * i, "status": "ok"} for i in range(1, 6) + ], + } + + # Adapter that tracks whether fetch_by_pk_chunked was used + class TrackingAdapter(ChunkedMockAdapter): + def __init__(self, schema, data): + 
super().__init__(schema, data, chunk_size=2) + self.fetch_by_pk_calls: list[str] = [] + + def fetch_by_pk(self, table, pk_columns, pk_values): + self.fetch_by_pk_calls.append(table) + return super().fetch_by_pk(table, pk_columns, pk_values) + + def fetch_by_pk_chunked(self, table, pk_columns, pk_values, chunk_size=1000): + self.chunked_fetch_calls.append( + {"table": table, "pk_count": len(pk_values), "chunk_size": chunk_size} + ) + # Yield in small chunks to prove we don't need all at once + all_rows = list(super().fetch_by_pk(table, pk_columns, pk_values)) + for i in range(0, len(all_rows), chunk_size): + yield all_rows[i : i + chunk_size] + + adapter = TrackingAdapter(schema, data) + adapter.connect("test://localhost/test") + + all_records = { + "users": {(i,) for i in range(1, 6)}, + "orders": {(i,) for i in range(1, 6)}, + } + + config = ExtractConfig( + database_url="test://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + stream=True, + output_file=str(tmp_path / "chunked_deferred.sql"), + streaming_chunk_size=2, + ) + + engine = ExtractionEngine(config) + engine.adapter = adapter + engine.schema = schema + + # Reset call tracking before the streaming extract + adapter.fetch_by_pk_calls.clear() + adapter.chunked_fetch_calls.clear() + + result = engine._do_streaming_extract( + db_type=DatabaseType.POSTGRESQL, + all_records=all_records, + insert_order=["orders", "users"], + broken_fks=[fk_users_orders], + cycle_infos=[], + all_paths=[], + used_deferred_cycle_strategy=True, + ) + + # Verify chunked fetch was used for the broken FK table + chunked_tables = [c["table"] for c in adapter.chunked_fetch_calls] + assert "users" in chunked_tables, ( + "Expected fetch_by_pk_chunked to be used for the broken FK table 'users'" + ) + + # The chunked_fetch_calls list confirms the chunked path was taken + # instead of loading all rows into memory with list(fetch_by_pk(...)). 
+ assert len(adapter.chunked_fetch_calls) > 0 + + # Result should still be valid + assert result.total_rows() > 0 + + def test_streaming_empty_table(sample_schema, tmp_path): """Test streaming handles empty tables gracefully.""" data = { diff --git a/tests/test_validation.py b/tests/test_validation.py index 815ea4c..c8ac3f1 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -477,6 +477,23 @@ def test_orphan_with_composite_pk(self): orphan = result.orphaned_records[0] assert orphan.fk_values == (1, 999) + def test_extract_fk_values_raises_on_missing_column(self, simple_schema): + """Test that missing FK columns raise KeyError instead of silently returning None.""" + validator = ExtractionValidator(simple_schema) + + # Row is missing the 'user_id' FK column entirely + tables = { + "users": [ + {"id": 1, "email": "alice@example.com"}, + ], + "orders": [ + {"id": 1, "total": 100.0}, # 'user_id' column missing + ], + } + + with pytest.raises(KeyError, match="user_id"): + validator.validate(tables) + class TestOrphanedRecord: """Tests for OrphanedRecord dataclass.""" diff --git a/tests/test_validators.py b/tests/test_validators.py index 5eb7dd8..9113cd3 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -453,3 +453,62 @@ def test_case_sensitivity(self): for case_variant in ["DROP", "drop", "Drop", "dRoP"]: with pytest.raises(IdentifierValidationError): validate_identifier(case_variant, "test") + + +class TestWhereClauseDelegation: + """Verify input_validators.validate_where_clause catches attacks via config delegation. + + These vectors were previously only caught by config.validate_where_clause. + After the delegation fix, the input_validators wrapper must catch them too. 
+ """ + + def test_dollar_quoting_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("$$DROP TABLE users$$") + + def test_tagged_dollar_quoting_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("$body$DROP TABLE users$body$") + + def test_type_cast_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1::int") + + def test_escape_string_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = E'\\x44ROP'") + + def test_pg_sleep_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 AND pg_sleep(10)") + + def test_lo_import_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = lo_import('/etc/passwd')") + + def test_dblink_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 AND dblink('host=evil.com')") + + def test_fullwidth_unicode_drop_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("\uff24\uff32\uff2f\uff30 TABLE users") + + def test_comment_sequences_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 -- comment") + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 /* comment */") + + def test_standalone_drop_keyword_blocked(self): + """DROP as a bare keyword (not just DROP TABLE) should be caught.""" + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("1=1 DROP TABLE users") + + def test_safe_clauses_still_pass(self): + """Ensure delegation does not cause false positives on 
safe inputs.""" + validate_where_clause("status = 'active'") + validate_where_clause("id IN (1, 2, 3)") + validate_where_clause("dropbox_id = 123") + validate_where_clause("category = 'DELETE'") + validate_where_clause("name = 'O''Brien'")