diff --git a/src/dbslice/adapters/postgresql.py b/src/dbslice/adapters/postgresql.py index 99f9266..85d558f 100644 --- a/src/dbslice/adapters/postgresql.py +++ b/src/dbslice/adapters/postgresql.py @@ -27,6 +27,7 @@ def __init__( profiler: Any = None, schema: str | None = None, allow_unsafe_where: bool = False, + statement_timeout_ms: int = 0, ): self._conn: Any = None self._schema_name = schema or "public" @@ -34,6 +35,7 @@ def __init__( self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE self.profiler = profiler self.allow_unsafe_where = allow_unsafe_where + self.statement_timeout_ms = statement_timeout_ms def connect(self, url: str) -> None: """Establish PostgreSQL connection.""" @@ -62,6 +64,16 @@ def connect(self, url: str) -> None: # Use autocommit for reads by default self._conn.autocommit = True + # Set statement_timeout to prevent runaway queries + if self.statement_timeout_ms > 0: + with self._conn.cursor() as cur: + cur.execute( + "SET statement_timeout = %s", (self.statement_timeout_ms,) + ) + logger.debug( + "statement_timeout set", timeout_ms=self.statement_timeout_ms + ) + # Set search_path so unqualified table names resolve to the target schema if self._schema_name != "public": with self._conn.cursor() as cur: @@ -628,13 +640,17 @@ def fetch_referencing_pks( cur.execute(query, params) rows = cur.fetchall() for row in rows: - result.add(row) + # Filter out NULL values (nullable FKs) + if None not in row: + result.add(row) tracker.record_rows(len(rows)) else: with self._conn.cursor() as cur: cur.execute(query, params) for row in cur.fetchall(): - result.add(row) + # Filter out NULL values (nullable FKs) + if None not in row: + result.add(row) logger.debug( "Fetched referencing PKs with batching", diff --git a/src/dbslice/cli.py b/src/dbslice/cli.py index ba6fa6e..de9f4d3 100644 --- a/src/dbslice/cli.py +++ b/src/dbslice/cli.py @@ -13,7 +13,9 @@ from dbslice import __version__ from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -276,6 +278,8 @@ def _build_extract_config( output_file_mode: int, schema: str | None = None, allow_unsafe_where: bool = False, + max_seed_rows: int = DEFAULT_MAX_SEED_ROWS, + statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS, ) -> ExtractConfig: """ Build ExtractConfig from validated CLI parameters. @@ -301,6 +305,7 @@ def _build_extract_config( stream_threshold: Auto-enable streaming above this row count stream_chunk_size: Streaming fetch chunk size output_file_mode: Permissions mode for output files + max_seed_rows: Maximum rows a single seed query may return Returns: Configured ExtractConfig object @@ -328,6 +333,8 @@ def _build_extract_config( output_file_mode=output_file_mode, schema=schema, allow_unsafe_where=bool(allow_unsafe_where), + max_seed_rows=max_seed_rows, + statement_timeout_ms=statement_timeout_ms, ) @@ -1095,6 +1102,20 @@ def extract( help="Number of rows to fetch per chunk in streaming mode (default: 1000)", ), ] = None, + max_seed_rows: Annotated[ + int | None, + typer.Option( + "--max-seed-rows", + help=f"Maximum rows a single seed query may return (default: {DEFAULT_MAX_SEED_ROWS})", + ), + ] = None, + statement_timeout: Annotated[ + int | None, + typer.Option( + "--statement-timeout", + help="PostgreSQL statement timeout in milliseconds (0 = no timeout, default: 0)", + ), + ] = None, output_file_mode: Annotated[ str | None, typer.Option( @@ -1257,6 +1278,12 @@ def extract( effective_stream_chunk_size = ( stream_chunk_size if stream_chunk_size is not None else DEFAULT_STREAMING_CHUNK_SIZE ) + effective_max_seed_rows = ( + max_seed_rows if max_seed_rows is not None else DEFAULT_MAX_SEED_ROWS + ) + effective_statement_timeout = ( + statement_timeout if statement_timeout is not None else DEFAULT_STATEMENT_TIMEOUT_MS + ) effective_output_file_mode = DEFAULT_OUTPUT_FILE_MODE loaded_config = None @@ -1377,6 +1404,10 @@ def extract( raise ValueError("--stream-threshold must be greater than 0") if effective_stream_chunk_size <= 0: raise ValueError("--stream-chunk-size must be greater than 0") + if effective_max_seed_rows <= 0: + raise ValueError("--max-seed-rows must be greater than 0") + if effective_statement_timeout < 0: + raise ValueError("--statement-timeout must be >= 0") except ValidationError as e: console.print(f"[red]Validation Error:[/red] {e}") raise typer.Exit(1) @@ -1454,8 +1485,13 @@ def extract( output_file_mode=effective_output_file_mode, schema=schema, allow_unsafe_where=effective_allow_unsafe_where, + max_seed_rows=effective_max_seed_rows, ) + # Apply CLI overrides to config + extract_config.max_seed_rows = effective_max_seed_rows + extract_config.statement_timeout_ms = effective_statement_timeout + # Apply compliance settings to extract config extract_config.compliance_profiles = effective_compliance extract_config.compliance_strict = effective_compliance_strict diff --git a/src/dbslice/compliance/manifest.py b/src/dbslice/compliance/manifest.py index fb33e52..c4d467f 100644 --- a/src/dbslice/compliance/manifest.py +++ b/src/dbslice/compliance/manifest.py @@ -200,6 +200,73 @@ def sign(self, signing_key: str) -> None: self.signature_algorithm = "hmac-sha256" self.signature = f"hmac-sha256:{digest}" + def validate_compliance( + self, + extracted_tables: dict[str, list[dict[str, Any]]], + profiles: list[str], + ) -> list[ManifestWarning]: + """Validate that all columns required by compliance profiles were actually anonymized. + + Compares the columns present in the extracted data against the manifest's + recorded actions. Any column that matches a compliance profile's required + pattern but was recorded as ``unmasked`` (or not recorded at all) is + flagged as a warning. + + Args: + extracted_tables: The actual extracted data (table -> rows). + profiles: Names of active compliance profiles. + + Returns: + List of ManifestWarning objects for columns that should have been + anonymized but were not. + """ + from dbslice.compliance.profiles import get_profile + + new_warnings: list[ManifestWarning] = [] + + # Build a set of (table, column) pairs that were masked or nulled + protected: set[tuple[str, str]] = set() + for table_name, table_entry in self.tables.items(): + for masked in table_entry.fields_masked: + protected.add((table_name, masked.column)) + for nulled in table_entry.fields_nulled: + protected.add((table_name, nulled.column)) + for col in table_entry.fields_preserved_fk: + protected.add((table_name, col)) + + for profile_name in profiles: + try: + profile = get_profile(profile_name) + except ValueError: + continue + + required_patterns = profile.required_column_patterns + + for table_name, rows in extracted_tables.items(): + if not rows: + continue + columns = list(rows[0].keys()) + for col in columns: + if (table_name, col) in protected: + continue + col_lower = col.lower() + for pattern in required_patterns: + if pattern in col_lower: + warning = ManifestWarning( + table=table_name, + column=col, + reason=( + f"Column matches {profile.display_name} required " + f"pattern '{pattern}' but was not anonymized" + ), + severity="error", + ) + new_warnings.append(warning) + self.warnings.append(warning) + break + + return new_warnings + def to_dict(self) -> dict[str, Any]: """Convert to a JSON-serializable dictionary.""" tables_dict: dict[str, Any] = {} diff --git a/src/dbslice/config.py b/src/dbslice/config.py index 98a2dea..bddeb2c 100644 --- a/src/dbslice/config.py +++ b/src/dbslice/config.py @@ -5,7 +5,9 @@ from typing import Any from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -316,6 +318,8 @@ class ExtractConfig: table_direction_overrides: dict[str, TraversalDirection] = field(default_factory=dict) row_limit_global: int | None = None row_limit_per_table: dict[str, int] = field(default_factory=dict) + max_seed_rows: int = DEFAULT_MAX_SEED_ROWS + statement_timeout_ms: int = DEFAULT_STATEMENT_TIMEOUT_MS anonymization_seed: str | None = None anonymization_field_providers: dict[str, str] = field(default_factory=dict) anonymization_patterns: dict[str, str] = field(default_factory=dict) diff --git a/src/dbslice/config_file.py b/src/dbslice/config_file.py index 78a8e3c..cae7264 100644 --- a/src/dbslice/config_file.py +++ b/src/dbslice/config_file.py @@ -10,7 +10,9 @@ from dbslice.config import ExtractConfig, OutputFormat, SeedSpec, TraversalDirection from dbslice.constants import ( + DEFAULT_MAX_SEED_ROWS, DEFAULT_OUTPUT_FILE_MODE, + DEFAULT_STATEMENT_TIMEOUT_MS, DEFAULT_STREAMING_CHUNK_SIZE, DEFAULT_STREAMING_THRESHOLD, DEFAULT_TRAVERSAL_DEPTH, @@ -42,6 +44,8 @@ "fail_on_validation_error", "max_rows_per_table", "allow_unsafe_where", + "max_seed_rows", + "statement_timeout_ms", } _ANONYMIZATION_KEYS = { "enabled", @@ -309,6 +313,12 @@ class ExtractionConfig: allow_unsafe_where: bool = False """Allow seed WHERE clauses with subqueries (trusted inputs only).""" + max_seed_rows: int | None = None + """Maximum rows a single seed query may return (None = use global default).""" + + statement_timeout_ms: int | None = None + """PostgreSQL statement_timeout in milliseconds (None = use global default).""" + @dataclass class AnonymizationConfig: @@ -584,6 +594,8 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": fail_on_validation_error=extraction_data.get("fail_on_validation_error", False), max_rows_per_table=extraction_data.get("max_rows_per_table"), allow_unsafe_where=extraction_data.get("allow_unsafe_where", False), + max_seed_rows=extraction_data.get("max_seed_rows"), + statement_timeout_ms=extraction_data.get("statement_timeout_ms"), ) if not isinstance(extraction.default_depth, int) or extraction.default_depth < 1: @@ -610,6 +622,15 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": not isinstance(extraction.max_rows_per_table, int) or extraction.max_rows_per_table <= 0 ): raise ValueError("'extraction.max_rows_per_table' must be a positive integer") + if extraction.max_seed_rows is not None and ( + not isinstance(extraction.max_seed_rows, int) or extraction.max_seed_rows <= 0 + ): + raise ValueError("'extraction.max_seed_rows' must be a positive integer") + if extraction.statement_timeout_ms is not None and ( + not isinstance(extraction.statement_timeout_ms, int) + or extraction.statement_timeout_ms < 0 + ): + raise ValueError("'extraction.statement_timeout_ms' must be a non-negative integer") anon_data = data.get("anonymization", {}) if not isinstance(anon_data, dict): @@ -1195,6 +1216,8 @@ def to_extract_config( virtual_foreign_keys=virtual_fks, schema=final_schema, allow_unsafe_where=final_allow_unsafe_where, + max_seed_rows=self.extraction.max_seed_rows or DEFAULT_MAX_SEED_ROWS, + statement_timeout_ms=self.extraction.statement_timeout_ms or DEFAULT_STATEMENT_TIMEOUT_MS, compliance_profiles=self.compliance.profiles, compliance_strict=self.compliance.strict, generate_manifest=self.compliance.generate_manifest @@ -1271,6 +1294,8 @@ def to_yaml(self, include_comments: bool = True) -> str: output.append(f" - {table}") if self.extraction.max_rows_per_table is not None: output.append(f" max_rows_per_table: {self.extraction.max_rows_per_table}") + if self.extraction.max_seed_rows is not None: + output.append(f" max_seed_rows: {self.extraction.max_seed_rows}") output.append("") if include_comments: diff --git a/src/dbslice/constants.py b/src/dbslice/constants.py index ee55822..bc947de 100644 --- a/src/dbslice/constants.py +++ b/src/dbslice/constants.py @@ -30,3 +30,12 @@ DEFAULT_OUTPUT_FILE_MODE = 0o600 """Secure default permissions for newly created output files.""" + +DEFAULT_MAX_SEED_ROWS = 10000 +"""Default maximum number of rows a single seed query may return.""" + +SEED_ROW_WARNING_THRESHOLD = 1000 +"""Warn when a seed query returns more than this many rows.""" + +DEFAULT_STATEMENT_TIMEOUT_MS = 0 +"""Default PostgreSQL statement_timeout in milliseconds (0 = no timeout).""" diff --git a/src/dbslice/core/cycles.py b/src/dbslice/core/cycles.py index 438a656..61aa78d 100644 --- a/src/dbslice/core/cycles.py +++ b/src/dbslice/core/cycles.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from dataclasses import dataclass from typing import Any @@ -252,42 +253,74 @@ def build_deferred_updates( Returns: List of DeferredUpdate objects describing the UPDATE statements needed """ - deferred_updates = [] + # Wrap each table's row list as a single-chunk iterator to reuse the + # chunked implementation, avoiding duplicated logic. + chunk_iterators: dict[str, Iterator[list[dict[str, Any]]]] = { + table: iter([rows]) for table, rows in tables_data.items() + } + return build_deferred_updates_chunked(fks_to_break, schema, chunk_iterators) + +def build_deferred_updates_chunked( + fks_to_break: list[ForeignKey], + schema: SchemaGraph, + chunk_iterators: dict[str, Iterator[list[dict[str, Any]]]], +) -> list[DeferredUpdate]: + """Build deferred UPDATE statements by processing rows in chunks. + + Unlike `build_deferred_updates` which requires all row data in memory, + this variant processes rows incrementally via chunk iterators. Only the + resulting DeferredUpdate objects (which contain just PK and FK values) + are accumulated in memory. + + Args: + fks_to_break: List of ForeignKey objects that were broken + schema: Database schema graph + chunk_iterators: Map of table name to an iterator yielding chunks + (lists) of row dicts + + Returns: + List of DeferredUpdate objects describing the UPDATE statements needed + """ + # Group broken FKs by source table so we can process each table's + # chunks once and check all relevant FKs per row. + fks_by_table: dict[str, list[ForeignKey]] = {} for fk in fks_to_break: - source_table = fk.source_table - fk_columns = fk.source_columns + fks_by_table.setdefault(fk.source_table, []).append(fk) - if source_table not in tables_data: + deferred_updates: list[DeferredUpdate] = [] + + for table, table_fks in fks_by_table.items(): + if table not in chunk_iterators: continue - table_info = schema.get_table(source_table) + table_info = schema.get_table(table) if not table_info: continue pk_columns = table_info.primary_key - for row_data in tables_data[source_table]: - for fk_col in fk_columns: - if fk_col not in row_data: - continue - - fk_value = row_data[fk_col] - - # Skip if value is already NULL (no update needed) - if fk_value is None: - continue - - pk_values = tuple(row_data[col] for col in pk_columns) - - deferred_updates.append( - DeferredUpdate( - table=source_table, - pk_columns=pk_columns, - pk_values=pk_values, - fk_column=fk_col, - fk_value=fk_value, - ) - ) + for chunk in chunk_iterators[table]: + for row_data in chunk: + for fk in table_fks: + for fk_col in fk.source_columns: + if fk_col not in row_data: + continue + + fk_value = row_data[fk_col] + if fk_value is None: + continue + + pk_values = tuple(row_data[col] for col in pk_columns) + + deferred_updates.append( + DeferredUpdate( + table=table, + pk_columns=pk_columns, + pk_values=pk_values, + fk_column=fk_col, + fk_value=fk_value, + ) + ) return deferred_updates diff --git a/src/dbslice/core/engine.py b/src/dbslice/core/engine.py index e9f9266..efe785a 100644 --- a/src/dbslice/core/engine.py +++ b/src/dbslice/core/engine.py @@ -12,7 +12,12 @@ SeedSpec, TraversalDirection, ) -from dbslice.constants import DEFAULT_ANONYMIZATION_SEED, DEFAULT_TRAVERSAL_DEPTH +from dbslice.constants import ( + DEFAULT_ANONYMIZATION_SEED, + DEFAULT_TRAVERSAL_DEPTH, + SEED_ROW_WARNING_THRESHOLD, +) +from dbslice.core.cycles import CycleInfo, DeferredUpdate from dbslice.core.graph import GraphTraverser, TraversalConfig, TraversalResult from dbslice.exceptions import ( ExtractionError, @@ -24,6 +29,7 @@ from dbslice.output.sql import SQLGenerator from dbslice.utils.anonymizer import DeterministicAnonymizer from dbslice.utils.connection import get_adapter_for_url, parse_database_url +from dbslice.utils.profiling import QueryProfiler from dbslice.validation import ExtractionValidator, ValidationResult logger = get_logger(__name__) @@ -72,11 +78,11 @@ class ExtractionResult: stats: dict[str, int] = field(default_factory=dict) traversal_path: list[str] = field(default_factory=list) has_cycles: bool = False - broken_fks: list[Any] = field(default_factory=list) # list[ForeignKey] - deferred_updates: list[Any] = field(default_factory=list) # list[DeferredUpdate] - cycle_infos: list[Any] = field(default_factory=list) # list[CycleInfo] + broken_fks: list[ForeignKey] = field(default_factory=list) + deferred_updates: list[DeferredUpdate] = field(default_factory=list) + cycle_infos: list[CycleInfo] = field(default_factory=list) validation_result: ValidationResult | None = None - profiler: Any = None # Optional QueryProfiler + profiler: QueryProfiler | None = None used_deferred_cycle_strategy: bool = False def total_rows(self) -> int: @@ -206,6 +212,7 @@ def extract(self) -> tuple[ExtractionResult, SchemaGraph]: profiler=profiler, schema=self.config.schema, allow_unsafe_where=self.config.allow_unsafe_where, + statement_timeout_ms=self.config.statement_timeout_ms, ) else: self.adapter = get_adapter_for_url(self.config.database_url) @@ -536,6 +543,17 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: for table, rows in tables_data.items(): self.manifest.set_table_row_count(table, len(rows)) + if self.config.compliance_profiles: + compliance_warnings = self.manifest.validate_compliance( + tables_data, self.config.compliance_profiles + ) + if compliance_warnings: + logger.warning( + "Compliance validation found unanonymized columns", + warning_count=len(compliance_warnings), + profiles=self.config.compliance_profiles, + ) + return ExtractionResult( tables=tables_data, insert_order=insert_order, @@ -994,6 +1012,22 @@ def _process_seed(self, seed: SeedSpec) -> TraversalResult: ) seed_rows = list(self.adapter.fetch_rows(seed.table, where_clause, params)) + if len(seed_rows) > self.config.max_seed_rows: + raise ExtractionError( + f"Seed query returned {len(seed_rows)} rows, exceeding limit of " + f"{self.config.max_seed_rows}. Use --max-seed-rows to increase the " + f"limit or refine your WHERE clause.", + table=seed.table, + ) + + if len(seed_rows) > SEED_ROW_WARNING_THRESHOLD: + logger.warning( + "Seed query returned a large number of rows", + table=seed.table, + row_count=len(seed_rows), + limit=self.config.max_seed_rows, + ) + if not seed_rows: seed_str = ( f"{seed.table}.{seed.column}={seed.value}" @@ -1060,8 +1094,8 @@ def _do_streaming_extract( db_type: DatabaseType, all_records: dict[str, set[tuple[Any, ...]]], insert_order: list[str], - broken_fks: list[Any], - cycle_infos: list[Any], + broken_fks: list[ForeignKey], + cycle_infos: list[CycleInfo], all_paths: list[str], used_deferred_cycle_strategy: bool, ) -> ExtractionResult: @@ -1091,27 +1125,36 @@ def _do_streaming_extract( from dbslice.core.streaming import StreamingExtractionEngine - deferred_updates = [] + deferred_updates: list[Any] = [] if broken_fks: - from dbslice.core.cycles import build_deferred_updates + from dbslice.core.cycles import build_deferred_updates_chunked self._log("cycles", f"Breaking {len(broken_fks)} circular reference(s)...") - # For streaming mode, we need to build deferred updates without having - # all data in memory. We fetch the necessary data on-demand. with logger.timed_operation("build_deferred_updates_streaming"): - temp_data = {} + tables_needed: set[str] = set() for fk in broken_fks: - table = fk.source_table - if table not in temp_data and table in all_records: - pk_values = all_records[table] - table_info = self.schema.get_table(table) - if table_info: - pk_columns = table_info.primary_key - rows = list(self.adapter.fetch_by_pk(table, pk_columns, pk_values)) - temp_data[table] = rows + if fk.source_table in all_records: + tables_needed.add(fk.source_table) + + chunk_iterators = {} + for table in tables_needed: + pk_values = all_records[table] + table_info = self.schema.get_table(table) + if table_info: + pk_columns = table_info.primary_key + chunk_iterators[table] = ( + self.adapter.fetch_by_pk_chunked( + table, + pk_columns, + pk_values, + chunk_size=self.config.streaming_chunk_size, + ) + ) - deferred_updates = build_deferred_updates(broken_fks, temp_data, self.schema) + deferred_updates = build_deferred_updates_chunked( + broken_fks, self.schema, chunk_iterators + ) logger.info( "Circular references resolved", @@ -1145,7 +1188,7 @@ def _do_streaming_extract( def _topological_sort( self, tables: set[str], db_type: DatabaseType - ) -> tuple[list[str], list[Any], list[Any], bool]: + ) -> tuple[list[str], list[ForeignKey], list[CycleInfo], bool]: """ Topologically sort tables based on FK dependencies with cycle handling. @@ -1158,7 +1201,6 @@ def _topological_sort( assert self.schema is not None from dbslice.core.cycles import ( - CycleInfo, break_cycles_at_nullable_fks, find_cycles_dfs, identify_cycle_fks, diff --git a/src/dbslice/core/graph.py b/src/dbslice/core/graph.py index 587cc3b..024aee8 100644 --- a/src/dbslice/core/graph.py +++ b/src/dbslice/core/graph.py @@ -5,6 +5,7 @@ from dbslice.adapters.base import DatabaseAdapter from dbslice.config import TraversalDirection from dbslice.constants import DEFAULT_TRAVERSAL_DEPTH +from dbslice.exceptions import ExtractionError from dbslice.logging import get_logger from dbslice.models import ForeignKey, SchemaGraph @@ -372,11 +373,12 @@ def _process_passthrough_tables( pk_columns = table_info.primary_key if not pk_columns: - logger.warning( - "Passthrough table has no primary key, skipping", + raise ExtractionError( + f"Passthrough table '{table}' has no primary key. " + f"Passthrough tables must have a primary key to extract all rows. " + f"Remove it from passthrough or add a primary key.", table=table, ) - continue logger.debug("Fetching all rows from passthrough table", table=table) all_pks = self.adapter.fetch_all_pks(table, pk_columns) diff --git a/src/dbslice/input_validators.py b/src/dbslice/input_validators.py index e870e7a..80052b9 100644 --- a/src/dbslice/input_validators.py +++ b/src/dbslice/input_validators.py @@ -187,10 +187,12 @@ def validate_column_name(column: str) -> None: def validate_where_clause(where_clause: str) -> None: """ - Validate a WHERE clause for basic safety. + Validate a WHERE clause for safety against SQL injection. - Note: This is basic validation. SQL injection protection should also - be handled by using parameterized queries in the database layer. + Delegates to the comprehensive implementation in ``config.validate_where_clause`` + which handles Unicode normalization, quote stripping, dollar-quoting, type casts, + and dangerous PostgreSQL functions. This wrapper adds a length check and converts + the exception type so callers that expect ``SeedValidationError`` are unaffected. Args: where_clause: The WHERE clause to validate @@ -211,30 +213,16 @@ def validate_where_clause(where_clause: str) -> None: f"WHERE clause too long (max {MAX_WHERE_CLAUSE_LENGTH} characters)", ) - if ";" in where_clause: + from dbslice.config import validate_where_clause as _config_validate + from dbslice.exceptions import InsecureWhereClauseError + + try: + _config_validate(where_clause) + except InsecureWhereClauseError as exc: raise SeedValidationError( where_clause, - "WHERE clause contains potentially dangerous SQL patterns (semicolon found)", - ) - - dangerous_patterns = [ - (r"\bdrop\s+table\b", "DROP TABLE"), - (r"\bdelete\s+from\b", "DELETE FROM"), - (r"\btruncate\b", "TRUNCATE"), - (r"\balter\s+table\b", "ALTER TABLE"), - (r"\bunion\s+select\b", "UNION SELECT"), - (r"\bexec\s*\(", "EXEC"), - (r"\bexecute\s*\(", "EXECUTE"), - (r"--", "SQL comment"), - (r"/\*", "SQL comment"), - ] - - where_lower = where_clause.lower() - for pattern, name in dangerous_patterns: - if re.search(pattern, where_lower, re.IGNORECASE): - raise SeedValidationError( - where_clause, f"WHERE clause contains potentially dangerous SQL patterns ({name})" - ) + f"WHERE clause contains potentially dangerous SQL patterns ({exc.dangerous_keyword})", + ) from exc def validate_seed_value(value: Any) -> None: diff --git a/src/dbslice/output/csv_out.py b/src/dbslice/output/csv_out.py index 4d30dd7..a5d8063 100644 --- a/src/dbslice/output/csv_out.py +++ b/src/dbslice/output/csv_out.py @@ -212,7 +212,7 @@ def _format_value(self, value: Any) -> str: Format a Python value as CSV field value. Type conversions: - - None -> empty string (CSV convention for NULL) + - None -> ``\\N`` (PostgreSQL COPY convention for NULL) - bool -> "true"/"false" - datetime -> ISO 8601 string - date -> ISO 8601 date string @@ -231,8 +231,9 @@ def _format_value(self, value: Any) -> str: String representation suitable for CSV """ if value is None: - # CSV convention: NULL is represented as empty field - return "" + # Use \N sentinel (PostgreSQL COPY convention) to distinguish + # NULL from empty string, enabling lossless round-trips. + return "\\N" if isinstance(value, bool): # Use lowercase for consistency with JSON diff --git a/src/dbslice/output/json_out.py b/src/dbslice/output/json_out.py index bdf6be3..28b6584 100644 --- a/src/dbslice/output/json_out.py +++ b/src/dbslice/output/json_out.py @@ -19,7 +19,7 @@ class DatabaseTypeEncoder(json.JSONEncoder): - date -> ISO 8601 date string (YYYY-MM-DD) - time -> ISO 8601 time string (HH:MM:SS[.ffffff]) - timedelta -> total seconds (as float) - - Decimal -> float + - Decimal -> string (preserves exact precision) - UUID -> string - bytes -> hex string - Any other non-serializable type -> string representation @@ -49,8 +49,8 @@ def default(self, obj: Any) -> Any: return obj.total_seconds() if isinstance(obj, Decimal): - # Convert to float for JSON compatibility - return float(obj) + # Convert to string to preserve exact precision for financial data + return str(obj) if isinstance(obj, UUID): return str(obj) diff --git a/src/dbslice/utils/anonymizer.py b/src/dbslice/utils/anonymizer.py index afbfddd..f47f6c9 100644 --- a/src/dbslice/utils/anonymizer.py +++ b/src/dbslice/utils/anonymizer.py @@ -266,6 +266,9 @@ def configure( ] self.security_null_fields = [pattern.lower() for pattern in (security_null_fields or [])] + # Validate that all explicit provider names are resolvable + self._validate_provider_names() + logger.info( "Anonymizer configured", redact_field_count=len(self.redact_fields), @@ -275,6 +278,36 @@ def configure( security_null_pattern_count=len(self.security_null_fields), ) + def _validate_provider_names(self) -> None: + """Validate that all configured provider names are valid Faker methods or custom transformers.""" + from dbslice.compliance.transformers import CUSTOM_TRANSFORMERS + + # Collect all explicitly configured provider names + all_providers: dict[str, str] = {} # provider -> source description + for field, provider in self.field_providers.items(): + all_providers.setdefault(provider, f"field_providers[{field}]") + for pattern, provider in self.custom_patterns: + all_providers.setdefault(provider, f"patterns[{pattern}]") + for pattern, provider in self.fallback_patterns: + all_providers.setdefault(provider, f"fallback_patterns[{pattern}]") + + invalid = [] + for provider, source in all_providers.items(): + if provider in CUSTOM_TRANSFORMERS: + continue + if hasattr(self.fake, provider): + continue + invalid.append((provider, source)) + + if invalid: + details = ", ".join(f"'{p}' (from {s})" for p, s in invalid) + raise ValueError( + f"Unknown anonymization provider(s): {details}. " + f"Each provider must be a valid Faker method or a custom transformer " + f"({', '.join(sorted(CUSTOM_TRANSFORMERS))}). " + f"Check for typos in your anonymization configuration." + ) + def _is_foreign_key_column(self, table: str, column: str) -> bool: """ Check if a column is part of a foreign key. diff --git a/src/dbslice/validation.py b/src/dbslice/validation.py index 240fafa..b5ac218 100644 --- a/src/dbslice/validation.py +++ b/src/dbslice/validation.py @@ -295,7 +295,7 @@ def _extract_fk_values( Returns: Tuple of FK values """ - return tuple(row.get(col) for col in fk_columns) + return tuple(row[col] for col in fk_columns) def _has_parent_record( self, diff --git a/tests/conftest.py b/tests/conftest.py index 93d94c4..f1c2cca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -305,7 +305,9 @@ def fetch_referencing_pks( fk_tuple = tuple(row[col] for col in fk_cols) if fk_tuple in target_pk_values: pk_tuple = tuple(row[col] for col in pk_cols) - result.add(pk_tuple) + # Filter out NULL values (nullable FKs) + if None not in pk_tuple: + result.add(pk_tuple) return result diff --git a/tests/test_anonymizer.py b/tests/test_anonymizer.py index dc86580..fec03e1 100644 --- a/tests/test_anonymizer.py +++ b/tests/test_anonymizer.py @@ -1,5 +1,7 @@ """Tests for anonymization functionality.""" +import pytest + from dbslice.models import Column, ForeignKey, SchemaGraph, Table from dbslice.utils.anonymizer import DeterministicAnonymizer @@ -548,3 +550,40 @@ def test_column_name_in_hash(self): assert result_firstname == result_firstname2 assert result_lastname == result_lastname2 + + def test_invalid_field_provider_raises(self): + """Typo in field_providers should raise ValueError at configure time.""" + anon = DeterministicAnonymizer() + with pytest.raises(ValueError, match="Unknown anonymization provider"): + anon.configure( + [], + field_providers={"users.email": "nonexistent_faker_method"}, + ) + + def test_invalid_pattern_provider_raises(self): + """Typo in pattern provider should raise ValueError at configure time.""" + anon = DeterministicAnonymizer() + with pytest.raises(ValueError, match="Unknown anonymization provider"): + anon.configure( + [], + patterns={"*.email": "hipaa_zip_3"}, # typo: should be hipaa_zip3 + ) + + def test_custom_transformer_names_are_valid(self): + """Custom transformer names (e.g., hipaa_zip3) should pass validation.""" + anon = DeterministicAnonymizer() + # Should not raise + anon.configure( + [], + patterns={"*.zip": "hipaa_zip3", "*.dob": "year_only"}, + ) + + def test_valid_faker_methods_pass_validation(self): + """Standard Faker method names should pass validation.""" + anon = DeterministicAnonymizer() + # Should not raise + anon.configure( + [], + field_providers={"users.email": "email", "users.name": "name"}, + patterns={"*.phone": "phone_number"}, + ) diff --git a/tests/test_compliance.py b/tests/test_compliance.py index c50b52b..7350913 100644 --- a/tests/test_compliance.py +++ b/tests/test_compliance.py @@ -699,3 +699,88 @@ def fetch_rows(self, table: str, where_clause: str, params: tuple[object, ...]): target_table=None, console=MagicMock(), ) + + +class TestManifestComplianceValidation: + """Tests for post-extraction compliance manifest validation.""" + + def test_detects_unanonymized_column(self): + """Columns matching compliance patterns that were not masked should be flagged.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + # Record that users.id was unmasked (fine) and users.email was unmasked (bad) + manifest.record_unmasked_field("users", "id") + manifest.record_unmasked_field("users", "email") + + extracted = { + "users": [{"id": 1, "email": "alice@example.com"}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + assert len(warnings) == 1 + assert warnings[0].table == "users" + assert warnings[0].column == "email" + assert warnings[0].severity == "error" + assert "GDPR" in warnings[0].reason + + def test_no_warning_for_masked_column(self): + """Masked columns should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + manifest.record_masked_field("users", "email", "email") + + extracted = { + "users": [{"id": 1, "email": "fake@example.com"}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + email_warnings = [w for w in warnings if w.column == "email"] + assert len(email_warnings) == 0 + + def test_no_warning_for_nulled_column(self): + """NULLed columns should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["hipaa"]) + + manifest.record_nulled_field("users", "password", "security_null_pattern") + + extracted = { + "users": [{"id": 1, "password": None}], + } + + warnings = manifest.validate_compliance(extracted, ["hipaa"]) + + password_warnings = [w for w in warnings if w.column == "password"] + assert len(password_warnings) == 0 + + def test_fk_preserved_columns_not_flagged(self): + """FK-preserved columns should not produce warnings even if they match patterns.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + manifest.record_fk_preserved("orders", "user_name_id") + + extracted = { + "orders": [{"id": 1, "user_name_id": 42}], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + + fk_warnings = [w for w in warnings if w.column == "user_name_id"] + assert len(fk_warnings) == 0 + + def test_empty_tables_no_warnings(self): + """Empty tables should not produce warnings.""" + manifest = ComplianceManifest() + manifest.initialize("test-123", compliance_profiles=["gdpr"]) + + extracted: dict[str, list[dict]] = { + "users": [], + } + + warnings = manifest.validate_compliance(extracted, ["gdpr"]) + assert len(warnings) == 0 diff --git a/tests/test_csv_output.py b/tests/test_csv_output.py index 37a4344..0d89462 100644 --- a/tests/test_csv_output.py +++ b/tests/test_csv_output.py @@ -133,7 +133,7 @@ def test_generate_per_table_separate_headers(self, per_table_generator, sample_t assert "id,user_id,total" in orders_lines[0] def test_format_value_none(self, generator): - assert generator._format_value(None) == "" + assert generator._format_value(None) == "\\N" def test_format_value_bool(self, generator): assert generator._format_value(True) == "true" @@ -307,7 +307,7 @@ def test_null_value_in_csv(self, generator, sample_tables_schema): reader = csv.DictReader(csv_output.strip().split("\n")) rows = list(reader) - assert rows[0]["name"] == "" # NULL becomes empty string in CSV + assert rows[0]["name"] == "\\N" # NULL uses \N sentinel (PostgreSQL COPY convention) def test_write_to_file_single_mode(self, generator, sample_tables_schema, tmp_path): tables_data = { diff --git a/tests/test_cycles.py b/tests/test_cycles.py index 71a0f99..25889fe 100644 --- a/tests/test_cycles.py +++ b/tests/test_cycles.py @@ -7,6 +7,7 @@ DeferredUpdate, break_cycles_at_nullable_fks, build_deferred_updates, + build_deferred_updates_chunked, find_cycles_dfs, identify_cycle_fks, select_nullable_fk_to_break, @@ -427,6 +428,148 @@ def test_format_where_clause_composite(self): assert " AND " in where_clause +class TestBuildDeferredUpdatesChunked: + """Tests for the chunked (streaming-friendly) deferred update builder.""" + + def test_matches_non_chunked_output(self): + """Chunked variant should produce same DeferredUpdates as the original.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + rows = [ + {"id": 1, "name": "Alice", "manager_id": 2}, + {"id": 2, "name": "Bob", "manager_id": 1}, + {"id": 3, "name": "Charlie", "manager_id": None}, + ] + + # Non-chunked + expected = build_deferred_updates([fk], {"employees": rows}, schema) + + # Chunked — split rows into multiple small chunks + def chunk_iter(): + yield [rows[0]] + yield [rows[1], rows[2]] + + chunked_result = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert len(chunked_result) == len(expected) + for a, b in zip(expected, chunked_result): + assert a.table == b.table + assert a.pk_columns == b.pk_columns + assert a.pk_values == b.pk_values + assert a.fk_column == b.fk_column + assert a.fk_value == b.fk_value + + def test_skips_null_fk_values(self): + """Chunked variant should skip NULL FK values like the original.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + def chunk_iter(): + yield [ + {"id": 1, "name": "Alice", "manager_id": None}, + {"id": 2, "name": "Bob", "manager_id": 1}, + ] + + updates = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert len(updates) == 1 + assert updates[0].pk_values == (2,) + assert updates[0].fk_value == 1 + + def test_empty_iterator(self): + """Chunked variant handles empty iterator.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + def chunk_iter(): + return iter([]) + + updates = build_deferred_updates_chunked( + [fk], schema, {"employees": chunk_iter()} + ) + + assert updates == [] + + def test_missing_table_in_iterators(self): + """Chunked variant handles missing table gracefully.""" + fk = ForeignKey( + name="fk_manager", + source_table="employees", + source_columns=("manager_id",), + target_table="employees", + target_columns=("id",), + is_nullable=True, + ) + + employees_table = Table( + name="employees", + schema="public", + columns=[], + primary_key=("id",), + foreign_keys=[], + ) + + schema = SchemaGraph(tables={"employees": employees_table}, edges=[fk]) + + # No iterator for the table + updates = build_deferred_updates_chunked([fk], schema, {}) + + assert updates == [] + + class TestCycleInfo: """Tests for CycleInfo dataclass.""" diff --git a/tests/test_json_output.py b/tests/test_json_output.py index d322790..ef130ec 100644 --- a/tests/test_json_output.py +++ b/tests/test_json_output.py @@ -43,6 +43,13 @@ def test_encode_decimal(self): result = json.dumps(d, cls=DatabaseTypeEncoder) assert "99.99" in result + def test_encode_decimal_large_value_preserves_precision(self): + """Verify that large Decimal values preserve exact precision (no float rounding).""" + d = Decimal("99999999999999.99") + result = json.dumps(d, cls=DatabaseTypeEncoder) + parsed = json.loads(result) + assert parsed == "99999999999999.99" + def test_encode_uuid(self): u = UUID("12345678-1234-5678-1234-567812345678") result = json.dumps(u, cls=DatabaseTypeEncoder) @@ -83,7 +90,7 @@ def test_encode_mixed_types(self): # Verify all types were encoded correctly assert "2024-01-15" in parsed["datetime"] assert parsed["date"] == "2024-01-15" - assert parsed["decimal"] == 99.99 + assert parsed["decimal"] == "99.99" assert "12345678-1234-5678" in parsed["uuid"] assert parsed["bytes"] == "0001" assert parsed["string"] == "test" @@ -272,7 +279,7 @@ def test_generate_with_special_types(self, sample_tables_schema): user = parsed["tables"]["users"][0] assert "2024-01-15" in user["created_at"] assert user["birthday"] == "1990-05-20" - assert user["balance"] == 1234.56 + assert user["balance"] == "1234.56" assert "12345678-1234-5678" in user["uuid"] assert user["avatar"] == "89504e47" # hex of b"\x89PNG" @@ -554,4 +561,4 @@ def test_full_extraction_workflow(self, tmp_path): assert len(parsed["tables"]["users"]) == 2 assert len(parsed["tables"]["orders"]) == 3 assert parsed["tables"]["users"][0]["email"] == "user1@example.com" - assert parsed["tables"]["orders"][0]["amount"] == 99.99 + assert parsed["tables"]["orders"][0]["amount"] == "99.99" diff --git a/tests/test_max_seed_rows.py b/tests/test_max_seed_rows.py new file mode 100644 index 0000000..b94d65b --- /dev/null +++ b/tests/test_max_seed_rows.py @@ -0,0 +1,203 @@ +"""Tests for max_seed_rows safety limit on seed queries.""" + +from unittest.mock import patch + +import pytest + +from dbslice.config import ExtractConfig, SeedSpec +from dbslice.constants import DEFAULT_MAX_SEED_ROWS, SEED_ROW_WARNING_THRESHOLD +from dbslice.core.engine import ExtractionEngine +from dbslice.exceptions import ExtractionError +from dbslice.models import Column, SchemaGraph, Table +from tests.conftest import MockAdapter + + +def _make_schema_and_adapter( + row_count: int, +) -> tuple[SchemaGraph, MockAdapter]: + """Create a simple schema with a users table and the given number of rows.""" + users_table = Table( + name="users", + schema="public", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column(name="status", data_type="TEXT", nullable=True, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + schema = SchemaGraph(tables={"users": users_table}, edges=[]) + rows = [{"id": i, "status": "active"} for i in range(1, row_count + 1)] + adapter = MockAdapter(schema, {"users": rows}) + return schema, adapter + + +class TestMaxSeedRowsDefault: + """The default max_seed_rows value works correctly.""" + + def test_default_value_is_set(self): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + ) + assert config.max_seed_rows == DEFAULT_MAX_SEED_ROWS + + def test_seed_within_default_limit_succeeds(self): + schema, adapter = _make_schema_and_adapter(5) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # Should not raise + engine._process_seed(config.seeds[0]) + + +class TestMaxSeedRowsExceedsLimit: + """Seed queries exceeding the limit raise ExtractionError.""" + + def test_exceeds_default_limit(self): + row_count = DEFAULT_MAX_SEED_ROWS + 1 + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="exceeding limit"): + engine._process_seed(config.seeds[0]) + + def test_exceeds_custom_limit(self): + schema, adapter = _make_schema_and_adapter(6) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=5, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="exceeding limit of 5"): + engine._process_seed(config.seeds[0]) + + def test_exactly_at_limit_succeeds(self): + schema, adapter = _make_schema_and_adapter(5) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=5, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # Should not raise -- exactly at limit, not over + engine._process_seed(config.seeds[0]) + + def test_error_message_includes_row_count_and_limit(self): + schema, adapter = _make_schema_and_adapter(20) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=10, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match=r"20 rows.*limit of 10"): + engine._process_seed(config.seeds[0]) + + def test_error_message_mentions_cli_flag(self): + schema, adapter = _make_schema_and_adapter(20) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=10, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with pytest.raises(ExtractionError, match="--max-seed-rows"): + engine._process_seed(config.seeds[0]) + + +class TestMaxSeedRowsCustomConfig: + """Custom max_seed_rows via config works correctly.""" + + def test_custom_limit_via_config(self): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + max_seed_rows=500, + ) + assert config.max_seed_rows == 500 + + def test_high_limit_allows_large_result(self): + schema, adapter = _make_schema_and_adapter(50) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=100, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + # 50 rows is under the 100 limit, should succeed + engine._process_seed(config.seeds[0]) + + +class TestSeedRowWarningThreshold: + """Warning is emitted when seed rows exceed the warning threshold.""" + + def test_warning_logged_above_threshold(self): + row_count = SEED_ROW_WARNING_THRESHOLD + 1 + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=row_count + 1, # Don't hit the hard limit + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with patch("dbslice.core.engine.logger") as mock_logger: + engine._process_seed(config.seeds[0]) + mock_logger.warning.assert_called_once() + call_args = mock_logger.warning.call_args + assert "large number of rows" in call_args[0][0] + + def test_no_warning_below_threshold(self): + row_count = SEED_ROW_WARNING_THRESHOLD + schema, adapter = _make_schema_and_adapter(row_count) + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + validate=False, + max_seed_rows=row_count + 1, + ) + engine = ExtractionEngine(config) + engine.schema = schema + engine.adapter = adapter + + with patch("dbslice.core.engine.logger") as mock_logger: + engine._process_seed(config.seeds[0]) + mock_logger.warning.assert_not_called() diff --git a/tests/test_passthrough.py b/tests/test_passthrough.py index ace170a..514a156 100644 --- a/tests/test_passthrough.py +++ b/tests/test_passthrough.py @@ -3,6 +3,7 @@ import pytest from dbslice.core.graph import GraphTraverser, TraversalConfig +from dbslice.exceptions import ExtractionError from dbslice.models import Column, ForeignKey, SchemaGraph, Table from tests.conftest import MockAdapter @@ -383,7 +384,7 @@ def test_passthrough_with_exclude_tables( def test_passthrough_table_without_pk(passthrough_schema: SchemaGraph): - """Test that passthrough tables without primary keys are skipped.""" + """Test that passthrough tables without primary keys raise an error.""" # Add a table without a primary key no_pk_table = Table( name="logs", @@ -418,10 +419,5 @@ def test_passthrough_table_without_pk(passthrough_schema: SchemaGraph): passthrough_tables={"logs", "countries"}, ) - result = traverser.traverse("users", seed_pks, config) - - # Table without PK should not be included - assert "logs" not in result.records - - # Table with PK should be included - assert "countries" in result.records + with pytest.raises(ExtractionError, match="no primary key"): + traverser.traverse("users", seed_pks, config) diff --git a/tests/test_performance.py b/tests/test_performance.py index f235b39..97e8b51 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -225,7 +225,7 @@ def test_engine_passes_configured_batch_size_to_adapter(monkeypatch): captured: dict[str, int | None] = {} class FakePostgreSQLAdapter: - def __init__(self, batch_size=None, profiler=None, schema=None, allow_unsafe_where=False): + def __init__(self, batch_size=None, profiler=None, schema=None, allow_unsafe_where=False, statement_timeout_ms=0): captured["batch_size"] = batch_size def connect(self, url: str) -> None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 503e6ff..6a36249 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -509,6 +509,153 @@ def test_streaming_with_cycles(sample_schema, tmp_path): assert "orders" in sql_content +def test_streaming_deferred_updates_use_chunked_fetch(sample_schema, tmp_path): + """Streaming mode should build deferred updates via chunked fetching, not list(). + + This verifies that the streaming extract path uses fetch_by_pk_chunked + (not fetch_by_pk with list()) when processing broken FK tables for + deferred updates, so that memory stays bounded. + """ + from dbslice.models import Column, ForeignKey, Table + + # Build a schema with a circular dependency: users <-> orders + users_table = Table( + name="users", + schema="main", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column(name="email", data_type="TEXT", nullable=False, is_primary_key=False), + Column(name="name", data_type="TEXT", nullable=True, is_primary_key=False), + Column( + name="last_order_id", + data_type="INTEGER", + nullable=True, + is_primary_key=False, + ), + ], + primary_key=("id",), + foreign_keys=[], + ) + + orders_table = Table( + name="orders", + schema="main", + columns=[ + Column(name="id", data_type="INTEGER", nullable=False, is_primary_key=True), + Column( + name="user_id", + data_type="INTEGER", + nullable=False, + is_primary_key=False, + ), + Column(name="total", data_type="REAL", nullable=True, is_primary_key=False), + Column(name="status", data_type="TEXT", nullable=True, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + + fk_orders_users = ForeignKey( + name="fk_orders_users", + source_table="orders", + source_columns=("user_id",), + target_table="users", + target_columns=("id",), + is_nullable=False, + ) + fk_users_orders = ForeignKey( + name="fk_users_orders", + source_table="users", + source_columns=("last_order_id",), + target_table="orders", + target_columns=("id",), + is_nullable=True, + ) + + from dbslice.models import SchemaGraph + + schema = SchemaGraph( + tables={"users": users_table, "orders": orders_table}, + edges=[fk_orders_users, fk_users_orders], + ) + + data = { + "users": [ + {"id": i, "email": f"u{i}@test.com", "name": f"U{i}", "last_order_id": i} + for i in range(1, 6) + ], + "orders": [ + {"id": i, "user_id": i, "total": 10.0 * i, "status": "ok"} for i in range(1, 6) + ], + } + + # Adapter that tracks whether fetch_by_pk_chunked was used + class TrackingAdapter(ChunkedMockAdapter): + def __init__(self, schema, data): + super().__init__(schema, data, chunk_size=2) + self.fetch_by_pk_calls: list[str] = [] + + def fetch_by_pk(self, table, pk_columns, pk_values): + self.fetch_by_pk_calls.append(table) + return super().fetch_by_pk(table, pk_columns, pk_values) + + def fetch_by_pk_chunked(self, table, pk_columns, pk_values, chunk_size=1000): + self.chunked_fetch_calls.append( + {"table": table, "pk_count": len(pk_values), "chunk_size": chunk_size} + ) + # Yield in small chunks to prove we don't need all at once + all_rows = list(super().fetch_by_pk(table, pk_columns, pk_values)) + for i in range(0, len(all_rows), chunk_size): + yield all_rows[i : i + chunk_size] + + adapter = TrackingAdapter(schema, data) + adapter.connect("test://localhost/test") + + all_records = { + "users": {(i,) for i in range(1, 6)}, + "orders": {(i,) for i in range(1, 6)}, + } + + config = ExtractConfig( + database_url="test://localhost/test", + seeds=[SeedSpec.parse("users.id=1")], + stream=True, + output_file=str(tmp_path / "chunked_deferred.sql"), + streaming_chunk_size=2, + ) + + engine = ExtractionEngine(config) + engine.adapter = adapter + engine.schema = schema + + # Reset call tracking before the streaming extract + adapter.fetch_by_pk_calls.clear() + adapter.chunked_fetch_calls.clear() + + result = engine._do_streaming_extract( + db_type=DatabaseType.POSTGRESQL, + all_records=all_records, + insert_order=["orders", "users"], + broken_fks=[fk_users_orders], + cycle_infos=[], + all_paths=[], + used_deferred_cycle_strategy=True, + ) + + # Verify chunked fetch was used for the broken FK table + chunked_tables = [c["table"] for c in adapter.chunked_fetch_calls] + assert "users" in chunked_tables, ( + "Expected fetch_by_pk_chunked to be used for the broken FK table 'users'" + ) + + # The chunked_fetch_calls list confirms the chunked path was taken + # instead of loading all rows into memory with list(fetch_by_pk(...)). + assert len(adapter.chunked_fetch_calls) > 0 + + # Result should still be valid + assert result.total_rows() > 0 + + def test_streaming_empty_table(sample_schema, tmp_path): """Test streaming handles empty tables gracefully.""" data = { diff --git a/tests/test_validation.py b/tests/test_validation.py index 815ea4c..c8ac3f1 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -477,6 +477,23 @@ def test_orphan_with_composite_pk(self): orphan = result.orphaned_records[0] assert orphan.fk_values == (1, 999) + def test_extract_fk_values_raises_on_missing_column(self, simple_schema): + """Test that missing FK columns raise KeyError instead of silently returning None.""" + validator = ExtractionValidator(simple_schema) + + # Row is missing the 'user_id' FK column entirely + tables = { + "users": [ + {"id": 1, "email": "alice@example.com"}, + ], + "orders": [ + {"id": 1, "total": 100.0}, # 'user_id' column missing + ], + } + + with pytest.raises(KeyError, match="user_id"): + validator.validate(tables) + class TestOrphanedRecord: """Tests for OrphanedRecord dataclass.""" diff --git a/tests/test_validators.py b/tests/test_validators.py index 5eb7dd8..9113cd3 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -453,3 +453,62 @@ def test_case_sensitivity(self): for case_variant in ["DROP", "drop", "Drop", "dRoP"]: with pytest.raises(IdentifierValidationError): validate_identifier(case_variant, "test") + + +class TestWhereClauseDelegation: + """Verify input_validators.validate_where_clause catches attacks via config delegation. + + These vectors were previously only caught by config.validate_where_clause. + After the delegation fix, the input_validators wrapper must catch them too. + """ + + def test_dollar_quoting_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("$$DROP TABLE users$$") + + def test_tagged_dollar_quoting_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("$body$DROP TABLE users$body$") + + def test_type_cast_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1::int") + + def test_escape_string_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = E'\\x44ROP'") + + def test_pg_sleep_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 AND pg_sleep(10)") + + def test_lo_import_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = lo_import('/etc/passwd')") + + def test_dblink_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 AND dblink('host=evil.com')") + + def test_fullwidth_unicode_drop_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("\uff24\uff32\uff2f\uff30 TABLE users") + + def test_comment_sequences_blocked(self): + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 -- comment") + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("id = 1 /* comment */") + + def test_standalone_drop_keyword_blocked(self): + """DROP as a bare keyword (not just DROP TABLE) should be caught.""" + with pytest.raises(SeedValidationError, match="dangerous SQL patterns"): + validate_where_clause("1=1 DROP TABLE users") + + def test_safe_clauses_still_pass(self): + """Ensure delegation does not cause false positives on safe inputs.""" + validate_where_clause("status = 'active'") + validate_where_clause("id IN (1, 2, 3)") + validate_where_clause("dropbox_id = 123") + validate_where_clause("category = 'DELETE'") + validate_where_clause("name = 'O''Brien'")