diff --git a/docs/user-guide/cli-reference.md b/docs/user-guide/cli-reference.md index 46afb23..2ed97ac 100644 --- a/docs/user-guide/cli-reference.md +++ b/docs/user-guide/cli-reference.md @@ -61,6 +61,7 @@ dbslice extract [OPTIONS] [DATABASE_URL] | Option | Type | Default | Description | |--------|------|---------|-------------| | `--seed`, `-s` | TEXT | *Required* | Seed record specification (repeatable) | +| `--allow-unsafe-where` / `--no-allow-unsafe-where` | FLAG | Disabled | Allow subqueries in seed WHERE clauses (trusted inputs only) | **Seed Formats:** - `table.column=value` - Simple equality (e.g., `orders.id=12345`) @@ -468,6 +469,9 @@ Foreign Keys (23) Self-references (potential cycles): categories.parent_id + +Potential implicit relationships: + audit_log.user_id -> users.id ``` ##### Inspect Specific Table @@ -570,10 +574,12 @@ Accepted formats: |----------|-------------|---------| | `DBSLICE_ANONYMIZE` | Enable anonymization | `true` | | `DBSLICE_REDACT_FIELDS` | Comma-separated redact fields | `users.ssn,payments.card` | +| `DBSLICE_ALLOW_UNSAFE_WHERE` | Allow seed subqueries for advanced filters | `true` | Accepted formats: - `DBSLICE_ANONYMIZE`: `1/0`, `true/false`, `yes/no`, or `on/off` (case-insensitive). - `DBSLICE_REDACT_FIELDS`: comma-separated `table.column` values. +- `DBSLICE_ALLOW_UNSAFE_WHERE`: `1/0`, `true/false`, `yes/no`, or `on/off` (case-insensitive). ### Examples diff --git a/docs/user-guide/configuration.md b/docs/user-guide/configuration.md index 7400fe6..d2b8b44 100644 --- a/docs/user-guide/configuration.md +++ b/docs/user-guide/configuration.md @@ -168,6 +168,7 @@ extraction: validate: boolean # Enable validation fail_on_validation_error: boolean # Stop on validation errors max_rows_per_table: integer # Optional global row soft-cap + allow_unsafe_where: boolean # Allow subqueries in seed WHERE clauses (trusted input only) ``` #### Fields @@ -180,6 +181,7 @@ extraction: | `validate` | Boolean | No | `true` | Validate extraction for referential integrity | | `fail_on_validation_error` | Boolean | No | `false` | Stop execution if validation finds issues | | `max_rows_per_table` | Integer | No | unlimited | Global per-table soft-cap with integrity closure | +| `allow_unsafe_where` | Boolean | No | `false` | Allow seed subqueries like `IN (SELECT ...)` for trusted inputs | `max_rows_per_table` is deterministic and integrity-first: - dbslice first caps each table deterministically by primary key sort. @@ -187,6 +189,11 @@ extraction: - Parent closure may exceed the configured cap. - If any row limit is configured, streaming mode is disabled automatically. +`allow_unsafe_where` notes: +- Default is `false` for security. +- When `true`, subqueries in seed WHERE clauses are allowed (for advanced filtering/join-style selection). +- Dangerous operations (`DROP`, `DELETE`, comments, stacked queries, etc.) are still blocked. + #### Examples ```yaml @@ -217,6 +224,10 @@ extraction: default_depth: 10 direction: up validate: true + +# Trusted advanced WHERE filters (subqueries) +extraction: + allow_unsafe_where: true ``` --- @@ -374,7 +385,7 @@ output: | `format` | String | No | `"sql"` | Output format: `sql`, `json`, or `csv` | | `include_transaction` | Boolean | No | `true` | Wrap SQL in BEGIN/COMMIT | | `include_truncate` | Boolean | No | `false` | Include `TRUNCATE TABLE ... CASCADE` before inserts | -| `disable_fk_checks` | Boolean | No | `false` | Disable FK checks during import | +| `disable_fk_checks` | Boolean | No | `false` | For PostgreSQL SQL output, emits deferred-constraint statements and enables non-nullable cycle fallback when FKs are DEFERRABLE | | `file_mode` | String/Octal | No | `"600"` | File permissions for generated outputs | | `json_mode` | String | No | `"single"` | JSON mode: `single` or `per-table` | | `json_pretty` | Boolean | No | `true` | Pretty-print JSON output | @@ -402,6 +413,10 @@ output: `include_drop_tables` is still accepted as a backward-compatible alias for `include_truncate`, but is deprecated. +Cycle note for PostgreSQL SQL imports: +- When cycles have no nullable FK, dbslice can still generate SQL if `disable_fk_checks: true` and cycle FKs are `DEFERRABLE`. +- If cycle FKs are not deferrable, extraction fails with a clear error. + ```yaml # JSON output (single file) output: diff --git a/src/dbslice/adapters/postgresql.py b/src/dbslice/adapters/postgresql.py index 3d6638d..99f9266 100644 --- a/src/dbslice/adapters/postgresql.py +++ b/src/dbslice/adapters/postgresql.py @@ -26,12 +26,14 @@ def __init__( batch_size: int | None = None, profiler: Any = None, schema: str | None = None, + allow_unsafe_where: bool = False, ): self._conn: Any = None self._schema_name = schema or "public" self._schema_cache: SchemaGraph | None = None self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE self.profiler = profiler + self.allow_unsafe_where = allow_unsafe_where def connect(self, url: str) -> None: """Establish PostgreSQL connection.""" @@ -216,7 +218,8 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]: a_source.attname AS source_column, target_cls.relname AS target_table, a_target.attname AS target_column, - NOT a_source.attnotnull AS is_nullable + NOT a_source.attnotnull AS is_nullable, + c.condeferrable AS is_deferrable FROM pg_constraint c JOIN pg_class source_cls ON c.conrelid = source_cls.oid JOIN pg_class target_cls ON c.confrelid = target_cls.oid @@ -239,7 +242,15 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]: # Group by constraint name for multi-column FKs fk_data: dict[str, dict] = {} for row in cur.fetchall(): - constraint_name, source_table, source_col, target_table, target_col, is_nullable = ( + ( + constraint_name, + source_table, + source_col, + target_table, + target_col, + is_nullable, + is_deferrable, + ) = ( row ) @@ -251,6 +262,7 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]: "target_table": target_table, "target_columns": [], "is_nullable": bool(is_nullable), + "is_deferrable": bool(is_deferrable), } fk_data[constraint_name]["source_columns"].append(source_col) @@ -265,6 +277,7 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]: target_table=data["target_table"], target_columns=tuple(data["target_columns"]), is_nullable=data["is_nullable"], + is_deferrable=data["is_deferrable"], ) ) @@ -291,7 +304,11 @@ def fetch_rows( # Defense-in-depth: validate WHERE clause even if it was validated earlier from dbslice.config import validate_where_clause - validate_where_clause(where_clause, f"{table}:{where_clause}") + validate_where_clause( + where_clause, + f"{table}:{where_clause}", + allow_unsafe_subqueries=self.allow_unsafe_where, + ) try: query = f'SELECT * FROM "{table}" WHERE {where_clause}' diff --git a/src/dbslice/cli.py b/src/dbslice/cli.py index 17222e9..f1e04df 100644 --- a/src/dbslice/cli.py +++ b/src/dbslice/cli.py @@ -186,6 +186,7 @@ def _parse_env_comma_list(var_name: str) -> list[str] | None: def _parse_and_validate_seeds( seeds: list[str], console: Console, + allow_unsafe_subqueries: bool = False, ) -> list[SeedSpec]: """ Parse and validate seed specifications from CLI arguments. @@ -203,7 +204,9 @@ def _parse_and_validate_seeds( parsed_seeds = [] for s in seeds: try: - parsed_seeds.append(SeedSpec.parse(s)) + parsed_seeds.append( + SeedSpec.parse(s, allow_unsafe_subqueries=allow_unsafe_subqueries) + ) except ValueError as e: raise InvalidSeedError(s, str(e)) @@ -270,6 +273,7 @@ def _build_extract_config( stream_chunk_size: int, output_file_mode: int, schema: str | None = None, + allow_unsafe_where: bool = False, ) -> ExtractConfig: """ Build ExtractConfig from validated CLI parameters. @@ -321,6 +325,7 @@ def _build_extract_config( streaming_chunk_size=stream_chunk_size, output_file_mode=output_file_mode, schema=schema, + allow_unsafe_where=bool(allow_unsafe_where), ) @@ -402,9 +407,14 @@ def _show_extraction_summary( if result.has_cycles: console.print() - console.print("[yellow]⚠ Circular dependencies detected and resolved[/yellow]") - console.print(f" Broken FKs: [cyan]{len(result.broken_fks)}[/cyan]") - console.print(f" Deferred UPDATEs: [cyan]{len(result.deferred_updates)}[/cyan]") + if result.used_deferred_cycle_strategy: + console.print("[yellow]⚠ Circular dependencies detected (deferred-constraint strategy)[/yellow]") + console.print(" Strategy: [cyan]Deterministic order + SET CONSTRAINTS ALL DEFERRED[/cyan]") + console.print(f" Cycles: [cyan]{len(result.cycle_infos)}[/cyan]") + else: + console.print("[yellow]⚠ Circular dependencies detected and resolved[/yellow]") + console.print(f" Broken FKs: [cyan]{len(result.broken_fks)}[/cyan]") + console.print(f" Deferred UPDATEs: [cyan]{len(result.deferred_updates)}[/cyan]") if config.verbose: for cycle_info in result.cycle_infos: console.print(f" [dim]Cycle: {cycle_info}[/dim]") @@ -966,6 +976,16 @@ def extract( help="PostgreSQL schema name (default: 'public')", ), ] = None, + allow_unsafe_where: Annotated[ + bool | None, + typer.Option( + "--allow-unsafe-where/--no-allow-unsafe-where", + help=( + "Allow seed WHERE subqueries (e.g. IN (SELECT ... JOIN ...)). " + "Use only with trusted inputs." + ), + ), + ] = None, ): """ Extract a database subset starting from seed record(s). @@ -1034,6 +1054,9 @@ def extract( redact_override = redact if redact_override is None: redact_override = _parse_env_comma_list("DBSLICE_REDACT_FIELDS") + allow_unsafe_where_override = allow_unsafe_where + if allow_unsafe_where_override is None: + allow_unsafe_where_override = _parse_env_bool("DBSLICE_ALLOW_UNSAFE_WHERE") except ValueError as e: console.print(f"[red]Validation Error:[/red] {e}") raise typer.Exit(1) @@ -1076,6 +1099,16 @@ def extract( ) raise typer.Exit(1) + effective_allow_unsafe_where = ( + allow_unsafe_where_override + if allow_unsafe_where_override is not None + else ( + loaded_config.extraction.allow_unsafe_where + if loaded_config is not None + else False + ) + ) + resolved_database_url = database_url_override if not resolved_database_url and loaded_config: resolved_database_url = loaded_config.database.url @@ -1138,7 +1171,11 @@ def extract( console.print(f"[red]Validation Error:[/red] {e}") raise typer.Exit(1) - seed_specs = _parse_and_validate_seeds(seed or [], console) + seed_specs = _parse_and_validate_seeds( + seed or [], + console, + allow_unsafe_subqueries=effective_allow_unsafe_where, + ) if loaded_config: direction_enum = ( @@ -1176,6 +1213,7 @@ def extract( if output_file_mode is not None else None, schema=schema, + allow_unsafe_where=allow_unsafe_where_override, ) output_format = extract_config.output_format else: @@ -1204,6 +1242,7 @@ def extract( stream_chunk_size=effective_stream_chunk_size, output_file_mode=effective_output_file_mode, schema=schema, + allow_unsafe_where=effective_allow_unsafe_where, ) if verbose and not no_progress: @@ -1498,6 +1537,59 @@ def _detect_sensitive_fields(schema) -> dict[str, str]: return detected +def _detect_potential_implicit_fks(schema) -> list[tuple[str, str, str]]: + """ + Detect likely implicit FK relationships using naming heuristics. + + Returns: + List of (source_table, source_column, target_table) tuples. + """ + existing_fk_columns: set[tuple[str, str]] = set() + for fk in schema.edges: + for source_col in fk.source_columns: + existing_fk_columns.add((fk.source_table, source_col)) + + tables_by_lower = {name.lower(): name for name in schema.tables.keys()} + + candidates: list[tuple[str, str, str]] = [] + + for table_name, table in schema.tables.items(): + for column in table.columns: + if (table_name, column.name) in existing_fk_columns: + continue + + col_lower = column.name.lower() + if not col_lower.endswith("_id"): + continue + + base = col_lower[:-3] + if not base: + continue + + guesses = [base, f"{base}s"] + if base.endswith("y") and len(base) > 1: + guesses.append(f"{base[:-1]}ies") + else: + guesses.append(f"{base}es") + + target_table: str | None = None + for guess in guesses: + actual = tables_by_lower.get(guess) + if not actual or actual == table_name: + continue + target_info = schema.tables.get(actual) + if not target_info or not target_info.primary_key: + continue + if "id" in {pk.lower() for pk in target_info.primary_key}: + target_table = actual + break + + if target_table: + candidates.append((table_name, column.name, target_table)) + + return sorted(candidates, key=lambda item: (item[0], item[1], item[2])) + + @app.command() def inspect( database_url: Annotated[ @@ -1594,6 +1686,21 @@ def inspect( f" [cyan]{child_table}[/cyan].{', '.join(fk.source_columns)}" ) + implicit_candidates = [ + candidate + for candidate in _detect_potential_implicit_fks(db_schema) + if candidate[0] == table + ] + if implicit_candidates: + console.print("\n [yellow]Potential implicit relationships:[/yellow]") + for src_table, src_col, target_table in implicit_candidates: + console.print( + f" {src_table}.{src_col} -> [cyan]{target_table}[/cyan].id" + ) + console.print( + " [dim]Tip: add these as virtual_foreign_keys if they are real relationships.[/dim]" + ) + else: console.print(f"\n[bold]Tables ({len(db_schema.tables)})[/bold]") for name in sorted(db_schema.tables.keys()): @@ -1616,6 +1723,19 @@ def inspect( for fk in self_refs: console.print(f" {fk.source_table}.{', '.join(fk.source_columns)}") + implicit_candidates = _detect_potential_implicit_fks(db_schema) + if implicit_candidates: + console.print("\n[yellow]Potential implicit relationships:[/yellow]") + for src_table, src_col, target_table in implicit_candidates[:25]: + console.print(f" {src_table}.{src_col} -> [cyan]{target_table}[/cyan].id") + if len(implicit_candidates) > 25: + console.print( + f" [dim]... and {len(implicit_candidates) - 25} more[/dim]" + ) + console.print( + " [dim]Tip: define virtual_foreign_keys for confirmed implicit links.[/dim]" + ) + finally: adapter.close() diff --git a/src/dbslice/config.py b/src/dbslice/config.py index 8722004..3ae619b 100644 --- a/src/dbslice/config.py +++ b/src/dbslice/config.py @@ -85,7 +85,9 @@ class OutputFormat(Enum): } -def validate_where_clause(where_clause: str, seed_str: str = "") -> None: +def validate_where_clause( + where_clause: str, seed_str: str = "", allow_unsafe_subqueries: bool = False +) -> None: """ Validate that a WHERE clause doesn't contain dangerous SQL keywords. @@ -96,6 +98,7 @@ def validate_where_clause(where_clause: str, seed_str: str = "") -> None: Args: where_clause: The WHERE clause to validate (without the WHERE keyword) seed_str: Original seed string for error reporting + allow_unsafe_subqueries: Allow subqueries (e.g. IN (SELECT ...)) when true Raises: InsecureWhereClauseError: If dangerous keywords are detected @@ -140,8 +143,8 @@ def validate_where_clause(where_clause: str, seed_str: str = "") -> None: if re.search(pattern, normalized_lower): raise InsecureWhereClauseError(seed_str or where_clause, func_name + "()") - # Block subqueries (SELECT inside parentheses) - if re.search(r"\(\s*SELECT\b", normalized_upper): + # Block subqueries (SELECT inside parentheses) unless explicitly opted in. + if not allow_unsafe_subqueries and re.search(r"\(\s*SELECT\b", normalized_upper): raise InsecureWhereClauseError(seed_str or where_clause, "subquery (SELECT)") # Block type casts with :: (PostgreSQL-specific, can be used to smuggle data) @@ -168,7 +171,7 @@ class SeedSpec: where_clause: str | None # Raw WHERE clause if provided @classmethod - def parse(cls, seed_str: str) -> "SeedSpec": + def parse(cls, seed_str: str, allow_unsafe_subqueries: bool = False) -> "SeedSpec": """ Parse a seed string into a SeedSpec with comprehensive validation. @@ -204,7 +207,11 @@ def parse(cls, seed_str: str) -> "SeedSpec": except Exception as e: raise ValueError(f"Invalid seed table name: {e}") - validate_where_clause(where_clause, seed_str) + validate_where_clause( + where_clause, + seed_str, + allow_unsafe_subqueries=allow_unsafe_subqueries, + ) return cls( table=table, @@ -256,7 +263,9 @@ def parse(cls, seed_str: str) -> "SeedSpec": "Use 'table.column=value' or 'table:WHERE_CLAUSE'" ) - def to_where_clause(self) -> tuple[str, tuple[Any, ...]]: + def to_where_clause( + self, allow_unsafe_subqueries: bool = False + ) -> tuple[str, tuple[Any, ...]]: """ Convert to WHERE clause and parameters. @@ -265,7 +274,11 @@ def to_where_clause(self) -> tuple[str, tuple[Any, ...]]: """ if self.where_clause: # Re-validate in case object was constructed directly (not via parse) - validate_where_clause(self.where_clause, f"{self.table}:{self.where_clause}") + validate_where_clause( + self.where_clause, + f"{self.table}:{self.where_clause}", + allow_unsafe_subqueries=allow_unsafe_subqueries, + ) return (self.where_clause, ()) else: return (f"{self.column} = %s", (self.value,)) @@ -309,3 +322,4 @@ class ExtractConfig: security_null_fields: list[str] = field(default_factory=list) virtual_foreign_keys: list[VirtualForeignKey] = field(default_factory=list) schema: str | None = None # PostgreSQL schema name (default: public) + allow_unsafe_where: bool = False diff --git a/src/dbslice/config_file.py b/src/dbslice/config_file.py index ea15c64..f1a45c6 100644 --- a/src/dbslice/config_file.py +++ b/src/dbslice/config_file.py @@ -40,6 +40,7 @@ "validate", "fail_on_validation_error", "max_rows_per_table", + "allow_unsafe_where", } _ANONYMIZATION_KEYS = { "enabled", @@ -290,6 +291,9 @@ class ExtractionConfig: max_rows_per_table: int | None = None """Global limit on rows per table (None = unlimited).""" + allow_unsafe_where: bool = False + """Allow seed WHERE clauses with subqueries (trusted inputs only).""" + @dataclass class AnonymizationConfig: @@ -525,6 +529,7 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": validate=extraction_data.get("validate", True), fail_on_validation_error=extraction_data.get("fail_on_validation_error", False), max_rows_per_table=extraction_data.get("max_rows_per_table"), + allow_unsafe_where=extraction_data.get("allow_unsafe_where", False), ) if not isinstance(extraction.default_depth, int) or extraction.default_depth < 1: @@ -545,6 +550,8 @@ def _from_dict(cls, data: dict[str, Any]) -> "DbsliceConfig": raise ValueError("'extraction.validate' must be true or false") if not isinstance(extraction.fail_on_validation_error, bool): raise ValueError("'extraction.fail_on_validation_error' must be true or false") + if not isinstance(extraction.allow_unsafe_where, bool): + raise ValueError("'extraction.allow_unsafe_where' must be true or false") if extraction.max_rows_per_table is not None and ( not isinstance(extraction.max_rows_per_table, int) or extraction.max_rows_per_table <= 0 ): @@ -827,6 +834,7 @@ def to_extract_config( stream_chunk_size: int | None = None, output_file_mode: int | None = None, schema: str | None = None, + allow_unsafe_where: bool | None = None, ) -> ExtractConfig: """ Convert to ExtractConfig for use by the extraction engine. @@ -855,6 +863,7 @@ def to_extract_config( stream_chunk_size: Streaming chunk size override (from CLI) output_file_mode: Output file permissions override (from CLI) schema: PostgreSQL schema name override (from CLI) + allow_unsafe_where: Unsafe WHERE override (from CLI/env) Returns: ExtractConfig ready for extraction @@ -926,6 +935,11 @@ def to_extract_config( if stream_chunk_size is not None else self.performance.streaming.chunk_size ) + final_allow_unsafe_where = ( + allow_unsafe_where + if allow_unsafe_where is not None + else self.extraction.allow_unsafe_where + ) final_output_file_mode = ( output_file_mode if output_file_mode is not None else self.output.file_mode ) @@ -1037,6 +1051,7 @@ def to_extract_config( security_null_fields=effective_security_null_fields, virtual_foreign_keys=virtual_fks, schema=final_schema, + allow_unsafe_where=final_allow_unsafe_where, ) def to_yaml(self, include_comments: bool = True) -> str: @@ -1093,6 +1108,8 @@ def to_yaml(self, include_comments: bool = True) -> str: " fail_on_validation_error: " f"{str(self.extraction.fail_on_validation_error).lower()}" ) + if self.extraction.allow_unsafe_where: + output.append(" allow_unsafe_where: true") if self.extraction.exclude_tables: output.append(" exclude_tables:") for table in self.extraction.exclude_tables: diff --git a/src/dbslice/core/engine.py b/src/dbslice/core/engine.py index 36edf63..b447851 100644 --- a/src/dbslice/core/engine.py +++ b/src/dbslice/core/engine.py @@ -8,6 +8,7 @@ from dbslice.config import ( DatabaseType, ExtractConfig, + OutputFormat, SeedSpec, TraversalDirection, ) @@ -76,6 +77,7 @@ class ExtractionResult: cycle_infos: list[Any] = field(default_factory=list) # list[CycleInfo] validation_result: ValidationResult | None = None profiler: Any = None # Optional QueryProfiler + used_deferred_cycle_strategy: bool = False def total_rows(self) -> int: """Get total number of extracted rows.""" @@ -169,6 +171,7 @@ def extract(self) -> tuple[ExtractionResult, SchemaGraph]: batch_size=self.config.db_batch_size, profiler=profiler, schema=self.config.schema, + allow_unsafe_where=self.config.allow_unsafe_where, ) else: self.adapter = get_adapter_for_url(self.config.database_url) @@ -279,7 +282,9 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: self._log("sort", "Sorting tables by dependencies...") with logger.timed_operation("topological_sort", table_count=len(all_records)): - insert_order, broken_fks, cycle_infos = self._topological_sort(set(all_records.keys())) + insert_order, broken_fks, cycle_infos, used_deferred_cycle_strategy = ( + self._topological_sort(set(all_records.keys()), db_type) + ) if broken_fks: logger.warning( @@ -320,10 +325,11 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: insert_order=insert_order, stats=dry_run_stats, traversal_path=all_paths, - has_cycles=len(broken_fks) > 0, + has_cycles=bool(cycle_infos), broken_fks=broken_fks, deferred_updates=[], cycle_infos=cycle_infos, + used_deferred_cycle_strategy=used_deferred_cycle_strategy, ) total_rows_estimate = sum(len(pks) for pks in all_records.values()) @@ -345,7 +351,13 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: ) return self._do_streaming_extract( - db_type, all_records, insert_order, broken_fks, cycle_infos, all_paths + db_type, + all_records, + insert_order, + broken_fks, + cycle_infos, + all_paths, + used_deferred_cycle_strategy, ) logger.info( @@ -465,11 +477,12 @@ def _do_extract(self, db_type: DatabaseType) -> ExtractionResult: insert_order=insert_order, stats=stats, traversal_path=all_paths, - has_cycles=len(broken_fks) > 0, + has_cycles=bool(cycle_infos), broken_fks=broken_fks, deferred_updates=deferred_updates, cycle_infos=cycle_infos, validation_result=validation_result, + used_deferred_cycle_strategy=used_deferred_cycle_strategy, ) def _anonymize_table_data(self, table: str, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -642,7 +655,9 @@ def _process_seed(self, seed: SeedSpec) -> TraversalResult: table=seed.table, ) - where_clause, params = seed.to_where_clause() + where_clause, params = seed.to_where_clause( + allow_unsafe_subqueries=self.config.allow_unsafe_where + ) seed_rows = list(self.adapter.fetch_rows(seed.table, where_clause, params)) if not seed_rows: @@ -714,6 +729,7 @@ def _do_streaming_extract( broken_fks: list[Any], cycle_infos: list[Any], all_paths: list[str], + used_deferred_cycle_strategy: bool, ) -> ExtractionResult: """ Perform streaming extraction to file. @@ -788,10 +804,14 @@ def _do_streaming_extract( result.traversal_path = all_paths result.cycle_infos = cycle_infos + result.has_cycles = bool(cycle_infos) + result.used_deferred_cycle_strategy = used_deferred_cycle_strategy return result - def _topological_sort(self, tables: set[str]) -> tuple[list[str], list[Any], list[Any]]: + def _topological_sort( + self, tables: set[str], db_type: DatabaseType + ) -> tuple[list[str], list[Any], list[Any], bool]: """ Topologically sort tables based on FK dependencies with cycle handling. @@ -799,11 +819,16 @@ def _topological_sort(self, tables: set[str]) -> tuple[list[str], list[Any], lis If cycles are detected, breaks them at nullable foreign keys. Returns: - Tuple of (insert_order, broken_fks, cycle_infos) + Tuple of (insert_order, broken_fks, cycle_infos, used_deferred_cycle_strategy) """ assert self.schema is not None - from dbslice.core.cycles import break_cycles_at_nullable_fks + from dbslice.core.cycles import ( + CycleInfo, + break_cycles_at_nullable_fks, + find_cycles_dfs, + identify_cycle_fks, + ) dependencies: dict[str, set[str]] = {t: set() for t in tables} @@ -815,16 +840,72 @@ def _topological_sort(self, tables: set[str]) -> tuple[list[str], list[Any], lis try: insert_order = list(ts.static_order()) - return insert_order, [], [] + return insert_order, [], [], False except CycleError: try: fks_to_break, cycle_infos = break_cycles_at_nullable_fks( self.schema, tables, dependencies ) except ValueError as e: - # No nullable FK found to break cycle + cycles = find_cycles_dfs(dependencies) + cycle_infos = [ + CycleInfo( + tables=cycle, + fks_in_cycle=identify_cycle_fks(self.schema, cycle), + ) + for cycle in cycles + ] + non_deferrable: list[str] = [] + + # For JSON/CSV output, deterministic ordering is enough. + if self.config.output_format != OutputFormat.SQL: + logger.warning( + "Cycle detected with no nullable FK; using deterministic output order for non-SQL format", + cycle_count=len(cycle_infos), + format=self.config.output_format.value, + ) + return sorted(tables), [], cycle_infos, False + + # SQL output can fallback only when explicitly requested and feasible. + if ( + self.config.disable_fk_checks + and db_type == DatabaseType.POSTGRESQL + and cycle_infos + ): + for cycle_info in cycle_infos: + for fk in cycle_info.fks_in_cycle: + if not fk.is_deferrable: + non_deferrable.append( + f"{fk.source_table}.{', '.join(fk.source_columns)} -> " + f"{fk.target_table}.{', '.join(fk.target_columns)}" + ) + if not non_deferrable: + logger.warning( + "Cycle detected with no nullable FK; using deferred-constraint strategy", + cycle_count=len(cycle_infos), + ) + return sorted(tables), [], cycle_infos, True + from dbslice.exceptions import CircularReferenceError + if self.config.disable_fk_checks and db_type == DatabaseType.POSTGRESQL: + if non_deferrable: + non_deferrable_msg = "\n".join(f" - {item}" for item in non_deferrable[:10]) + if len(non_deferrable) > 10: + non_deferrable_msg += f"\n ... and {len(non_deferrable) - 10} more" + elif cycle_infos: + non_deferrable_msg = ( + " (cycle details available, but deferrability could not be determined)" + ) + else: + non_deferrable_msg = " (no cycle details available)" + raise CircularReferenceError( + "Circular dependency requires deferrable constraints for SQL fallback.\n\n" + "Non-deferrable FKs in detected cycles:\n" + f"{non_deferrable_msg}\n\n" + "Make at least one FK nullable (preferred), or change cycle FKs to DEFERRABLE." + ) + raise CircularReferenceError(str(e)) modified_deps = {t: set(deps) for t, deps in dependencies.items()} @@ -835,7 +916,7 @@ def _topological_sort(self, tables: set[str]) -> tuple[list[str], list[Any], lis ts = TopologicalSorter(modified_deps) insert_order = list(ts.static_order()) - return insert_order, fks_to_break, cycle_infos + return insert_order, fks_to_break, cycle_infos, False def extract_subset( diff --git a/src/dbslice/models.py b/src/dbslice/models.py index b0025b4..24eeeaf 100644 --- a/src/dbslice/models.py +++ b/src/dbslice/models.py @@ -26,6 +26,7 @@ class ForeignKey: target_table: str target_columns: tuple[str, ...] is_nullable: bool + is_deferrable: bool = False def __hash__(self) -> int: """Hash for use in sets and as dict keys.""" @@ -89,6 +90,7 @@ def to_foreign_key(self) -> ForeignKey: target_table=self.target_table, target_columns=self.target_columns, is_nullable=self.is_nullable, + is_deferrable=False, ) diff --git a/tests/test_cli_env.py b/tests/test_cli_env.py index 5efde19..55d7ee2 100644 --- a/tests/test_cli_env.py +++ b/tests/test_cli_env.py @@ -45,6 +45,7 @@ def test_extract_cli_flags_override_env(monkeypatch, capture_extract): monkeypatch.setenv("DBSLICE_OUTPUT_FORMAT", "json") monkeypatch.setenv("DBSLICE_ANONYMIZE", "false") monkeypatch.setenv("DBSLICE_REDACT_FIELDS", "users.email,users.phone") + monkeypatch.setenv("DBSLICE_ALLOW_UNSAFE_WHERE", "false") cli.extract( database_url="postgresql://cli_user:cli_pass@localhost:5432/clidb", @@ -54,6 +55,7 @@ def test_extract_cli_flags_override_env(monkeypatch, capture_extract): output="sql", anonymize=True, redact=["users.custom_field"], + allow_unsafe_where=True, no_progress=True, ) @@ -63,6 +65,7 @@ def test_extract_cli_flags_override_env(monkeypatch, capture_extract): assert extract_config.output_format == OutputFormat.SQL assert extract_config.anonymize is True assert extract_config.redact_fields == ["users.custom_field"] + assert extract_config.allow_unsafe_where is True def test_extract_env_overrides_config_when_cli_missing(monkeypatch, tmp_path, capture_extract): @@ -75,6 +78,7 @@ def test_extract_env_overrides_config_when_cli_missing(monkeypatch, tmp_path, ca "extraction:", " default_depth: 7", " direction: down", + " allow_unsafe_where: false", "anonymization:", " enabled: false", "output:", @@ -90,6 +94,7 @@ def test_extract_env_overrides_config_when_cli_missing(monkeypatch, tmp_path, ca monkeypatch.setenv("DBSLICE_OUTPUT_FORMAT", "json") monkeypatch.setenv("DBSLICE_ANONYMIZE", "true") monkeypatch.setenv("DBSLICE_REDACT_FIELDS", "users.ssn,users.passport") + monkeypatch.setenv("DBSLICE_ALLOW_UNSAFE_WHERE", "true") cli.extract( config=Path(config_path), @@ -104,6 +109,7 @@ def test_extract_env_overrides_config_when_cli_missing(monkeypatch, tmp_path, ca assert extract_config.output_format == OutputFormat.JSON assert extract_config.anonymize is True assert extract_config.redact_fields == ["users.ssn", "users.passport"] + assert extract_config.allow_unsafe_where is True def test_extract_invalid_env_value_fails_fast(monkeypatch, capsys): diff --git a/tests/test_cli_inspect_helpers.py b/tests/test_cli_inspect_helpers.py new file mode 100644 index 0000000..7ac44a4 --- /dev/null +++ b/tests/test_cli_inspect_helpers.py @@ -0,0 +1,61 @@ +"""Tests for inspect helper heuristics.""" + +from dbslice.cli import _detect_potential_implicit_fks +from dbslice.models import Column, ForeignKey, SchemaGraph, Table + + +def _table(name: str, columns: list[Column], pk: tuple[str, ...] = ("id",)) -> Table: + return Table( + name=name, + schema="public", + columns=columns, + primary_key=pk, + foreign_keys=[], + ) + + +def test_detect_potential_implicit_fks_suggests_missing_user_id_relationship(): + users = _table( + "users", + [Column(name="id", data_type="integer", nullable=False, is_primary_key=True)], + ) + audit_log = _table( + "audit_log", + [ + Column(name="id", data_type="integer", nullable=False, is_primary_key=True), + Column(name="user_id", data_type="integer", nullable=False, is_primary_key=False), + ], + ) + schema = SchemaGraph(tables={"users": users, "audit_log": audit_log}, edges=[]) + + candidates = _detect_potential_implicit_fks(schema) + + assert ("audit_log", "user_id", "users") in candidates + + +def test_detect_potential_implicit_fks_skips_columns_with_real_fk(): + users = _table( + "users", + [Column(name="id", data_type="integer", nullable=False, is_primary_key=True)], + ) + orders = _table( + "orders", + [ + Column(name="id", data_type="integer", nullable=False, is_primary_key=True), + Column(name="user_id", data_type="integer", nullable=False, is_primary_key=False), + ], + ) + fk = ForeignKey( + name="fk_orders_users", + source_table="orders", + source_columns=("user_id",), + target_table="users", + target_columns=("id",), + is_nullable=False, + ) + orders.foreign_keys.append(fk) + schema = SchemaGraph(tables={"users": users, "orders": orders}, edges=[fk]) + + candidates = _detect_potential_implicit_fks(schema) + + assert ("orders", "user_id", "users") not in candidates diff --git a/tests/test_config_file.py b/tests/test_config_file.py index 832c2be..6d0c924 100644 --- a/tests/test_config_file.py +++ b/tests/test_config_file.py @@ -57,6 +57,7 @@ def test_default_values(self): assert config.validate is True assert config.fail_on_validation_error is False assert config.max_rows_per_table is None + assert config.allow_unsafe_where is False def test_custom_values(self): config = ExtractionConfig( @@ -492,6 +493,23 @@ def test_from_yaml_invalid_direction(self): finally: Path(temp_path).unlink() + def test_from_yaml_invalid_allow_unsafe_where_type(self): + yaml_content = """ +extraction: + allow_unsafe_where: "yes" +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(yaml_content) + f.flush() + temp_path = f.name + + try: + with pytest.raises(ConfigFileError) as exc_info: + DbsliceConfig.from_yaml(temp_path) + assert "allow_unsafe_where" in str(exc_info.value) + finally: + Path(temp_path).unlink() + def test_from_yaml_invalid_output_format(self): yaml_content = """ output: @@ -1026,6 +1044,24 @@ def test_extraction_validation_defaults_come_from_config(self): assert extract_config.validate is False assert extract_config.fail_on_validation_error is True + def test_allow_unsafe_where_propagates_from_config(self): + config = DbsliceConfig( + database=DatabaseConfig(url="postgres://localhost/test"), + extraction=ExtractionConfig(allow_unsafe_where=True), + ) + seeds = [SeedSpec.parse("users.id=1")] + extract_config = config.to_extract_config(seeds=seeds) + assert extract_config.allow_unsafe_where is True + + def test_allow_unsafe_where_cli_override_wins(self): + config = DbsliceConfig( + database=DatabaseConfig(url="postgres://localhost/test"), + extraction=ExtractionConfig(allow_unsafe_where=False), + ) + seeds = [SeedSpec.parse("users.id=1")] + extract_config = config.to_extract_config(seeds=seeds, allow_unsafe_where=True) + assert extract_config.allow_unsafe_where is True + class TestToYaml: """Tests for exporting config to YAML.""" diff --git a/tests/test_cycle_fallback.py b/tests/test_cycle_fallback.py new file mode 100644 index 0000000..cf88cea --- /dev/null +++ b/tests/test_cycle_fallback.py @@ -0,0 +1,113 @@ +"""Tests for non-nullable cycle fallback behavior.""" + +import pytest + +from dbslice.config import DatabaseType, ExtractConfig, OutputFormat, SeedSpec +from dbslice.core.engine import ExtractionEngine +from dbslice.exceptions import CircularReferenceError +from dbslice.models import Column, ForeignKey, SchemaGraph, Table + + +def _make_cycle_schema(*, deferrable: bool) -> SchemaGraph: + table_a = Table( + name="a", + schema="public", + columns=[ + Column(name="id", data_type="integer", nullable=False, is_primary_key=True), + Column(name="b_id", data_type="integer", nullable=False, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + table_b = Table( + name="b", + schema="public", + columns=[ + Column(name="id", data_type="integer", nullable=False, is_primary_key=True), + Column(name="a_id", data_type="integer", nullable=False, is_primary_key=False), + ], + primary_key=("id",), + foreign_keys=[], + ) + + fk_a_to_b = ForeignKey( + name="fk_a_b", + source_table="a", + source_columns=("b_id",), + target_table="b", + target_columns=("id",), + is_nullable=False, + is_deferrable=deferrable, + ) + fk_b_to_a = ForeignKey( + name="fk_b_a", + source_table="b", + source_columns=("a_id",), + target_table="a", + target_columns=("id",), + is_nullable=False, + is_deferrable=deferrable, + ) + + table_a.foreign_keys.append(fk_a_to_b) + table_b.foreign_keys.append(fk_b_to_a) + + return SchemaGraph( + tables={"a": table_a, "b": table_b}, + edges=[fk_a_to_b, fk_b_to_a], + ) + + +def test_non_sql_format_uses_deterministic_cycle_fallback(): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("a.id=1")], + output_format=OutputFormat.JSON, + ) + engine = ExtractionEngine(config) + engine.schema = _make_cycle_schema(deferrable=False) + + insert_order, broken_fks, cycle_infos, used_deferred_cycle_strategy = engine._topological_sort( + {"a", "b"}, + DatabaseType.POSTGRESQL, + ) + + assert insert_order == ["a", "b"] + assert broken_fks == [] + assert len(cycle_infos) == 1 + assert used_deferred_cycle_strategy is False + + +def test_sql_disable_fk_checks_uses_deferred_cycle_fallback_when_deferrable(): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("a.id=1")], + output_format=OutputFormat.SQL, + disable_fk_checks=True, + ) + engine = ExtractionEngine(config) + engine.schema = _make_cycle_schema(deferrable=True) + + insert_order, broken_fks, cycle_infos, used_deferred_cycle_strategy = engine._topological_sort( + {"a", "b"}, + DatabaseType.POSTGRESQL, + ) + + assert insert_order == ["a", "b"] + assert broken_fks == [] + assert len(cycle_infos) == 1 + assert used_deferred_cycle_strategy is True + + +def test_sql_disable_fk_checks_still_fails_for_non_deferrable_cycles(): + config = ExtractConfig( + database_url="postgresql://localhost/test", + seeds=[SeedSpec.parse("a.id=1")], + output_format=OutputFormat.SQL, + disable_fk_checks=True, + ) + engine = ExtractionEngine(config) + engine.schema = _make_cycle_schema(deferrable=False) + + with pytest.raises(CircularReferenceError, match="deferrable constraints"): + engine._topological_sort({"a", "b"}, DatabaseType.POSTGRESQL) diff --git a/tests/test_performance.py b/tests/test_performance.py index 5cf0c3b..f235b39 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -225,7 +225,7 @@ def test_engine_passes_configured_batch_size_to_adapter(monkeypatch): captured: dict[str, int | None] = {} class FakePostgreSQLAdapter: - def __init__(self, batch_size=None, profiler=None, schema=None): + def __init__(self, batch_size=None, profiler=None, schema=None, allow_unsafe_where=False): captured["batch_size"] = batch_size def connect(self, url: str) -> None: diff --git a/tests/test_security.py b/tests/test_security.py index e390926..9506093 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -258,6 +258,23 @@ def test_subquery_with_spaces(self): with pytest.raises(InsecureWhereClauseError): validate_where_clause("id = ( SELECT id FROM users )") + def test_subquery_allowed_when_explicitly_opted_in(self): + """Subqueries can be enabled for trusted complex filters.""" + validate_where_clause( + "id IN (SELECT user_id FROM orders WHERE total > 100)", + allow_unsafe_subqueries=True, + ) + + def test_seedspec_subquery_allowed_when_opted_in(self): + """SeedSpec parser supports explicit unsafe-subquery opt-in.""" + seed = SeedSpec.parse( + "users:id IN (SELECT user_id FROM orders WHERE total > 100)", + allow_unsafe_subqueries=True, + ) + where, params = seed.to_where_clause(allow_unsafe_subqueries=True) + assert "SELECT user_id FROM orders" in where + assert params == () + class TestNestedQuotingBypass: """Test that unbalanced and tricky quoting doesn't bypass validation.""" diff --git a/tests/test_virtual_fks.py b/tests/test_virtual_fks.py index 41781ae..ff0c0cd 100644 --- a/tests/test_virtual_fks.py +++ b/tests/test_virtual_fks.py @@ -241,6 +241,7 @@ def test_virtual_fk_to_foreign_key(self): assert fk.target_table == vfk.target_table assert fk.target_columns == vfk.target_columns assert fk.is_nullable == vfk.is_nullable + assert fk.is_deferrable is False def test_virtual_fk_hash(self): vfk1 = VirtualForeignKey(