Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/user-guide/cli-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dbslice extract [OPTIONS] [DATABASE_URL]
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--seed`, `-s` | TEXT | *Required* | Seed record specification (repeatable) |
| `--allow-unsafe-where` / `--no-allow-unsafe-where` | FLAG | Disabled | Allow subqueries in seed WHERE clauses (trusted inputs only) |

**Seed Formats:**
- `table.column=value` - Simple equality (e.g., `orders.id=12345`)
Expand Down Expand Up @@ -468,6 +469,9 @@ Foreign Keys (23)

Self-references (potential cycles):
categories.parent_id

Potential implicit relationships:
audit_log.user_id -> users.id
```

##### Inspect Specific Table
Expand Down Expand Up @@ -570,10 +574,12 @@ Accepted formats:
|----------|-------------|---------|
| `DBSLICE_ANONYMIZE` | Enable anonymization | `true` |
| `DBSLICE_REDACT_FIELDS` | Comma-separated redact fields | `users.ssn,payments.card` |
| `DBSLICE_ALLOW_UNSAFE_WHERE` | Allow seed subqueries for advanced filters | `true` |

Accepted formats:
- `DBSLICE_ANONYMIZE`: `1/0`, `true/false`, `yes/no`, or `on/off` (case-insensitive).
- `DBSLICE_REDACT_FIELDS`: comma-separated `table.column` values.
- `DBSLICE_ALLOW_UNSAFE_WHERE`: `1/0`, `true/false`, `yes/no`, or `on/off` (case-insensitive).

### Examples

Expand Down
17 changes: 16 additions & 1 deletion docs/user-guide/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ extraction:
validate: boolean # Enable validation
fail_on_validation_error: boolean # Stop on validation errors
max_rows_per_table: integer # Optional global row soft-cap
allow_unsafe_where: boolean # Allow subqueries in seed WHERE clauses (trusted input only)
```

#### Fields
Expand All @@ -180,13 +181,19 @@ extraction:
| `validate` | Boolean | No | `true` | Validate extraction for referential integrity |
| `fail_on_validation_error` | Boolean | No | `false` | Stop execution if validation finds issues |
| `max_rows_per_table` | Integer | No | unlimited | Global per-table soft-cap with integrity closure |
| `allow_unsafe_where` | Boolean | No | `false` | Allow seed subqueries like `IN (SELECT ...)` for trusted inputs |

`max_rows_per_table` is deterministic and integrity-first:
- dbslice first caps each table deterministically by primary key sort.
- It then adds required parent rows so FK integrity is preserved.
- Parent closure may exceed the configured cap.
- If any row limit is configured, streaming mode is disabled automatically.

`allow_unsafe_where` notes:
- Default is `false` for security.
- When `true`, subqueries in seed WHERE clauses are allowed (for advanced filtering/join-style selection).
- Dangerous operations (`DROP`, `DELETE`, comments, stacked queries, etc.) are still blocked.

#### Examples

```yaml
Expand Down Expand Up @@ -217,6 +224,10 @@ extraction:
default_depth: 10
direction: up
validate: true

# Trusted advanced WHERE filters (subqueries)
extraction:
allow_unsafe_where: true
```

---
Expand Down Expand Up @@ -374,7 +385,7 @@ output:
| `format` | String | No | `"sql"` | Output format: `sql`, `json`, or `csv` |
| `include_transaction` | Boolean | No | `true` | Wrap SQL in BEGIN/COMMIT |
| `include_truncate` | Boolean | No | `false` | Include `TRUNCATE TABLE ... CASCADE` before inserts |
| `disable_fk_checks` | Boolean | No | `false` | Disable FK checks during import |
| `disable_fk_checks` | Boolean | No | `false` | For PostgreSQL SQL output, emits deferred-constraint statements and enables non-nullable cycle fallback when FKs are DEFERRABLE |
| `file_mode` | String/Octal | No | `"600"` | File permissions for generated outputs |
| `json_mode` | String | No | `"single"` | JSON mode: `single` or `per-table` |
| `json_pretty` | Boolean | No | `true` | Pretty-print JSON output |
Expand Down Expand Up @@ -402,6 +413,10 @@ output:

`include_drop_tables` is still accepted as a backward-compatible alias for `include_truncate`, but is deprecated.

Cycle note for PostgreSQL SQL imports:
- When cycles have no nullable FK, dbslice can still generate SQL if `disable_fk_checks: true` and cycle FKs are `DEFERRABLE`.
- If cycle FKs are not deferrable, extraction fails with a clear error.

```yaml
# JSON output (single file)
output:
Expand Down
23 changes: 20 additions & 3 deletions src/dbslice/adapters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def __init__(
batch_size: int | None = None,
profiler: Any = None,
schema: str | None = None,
allow_unsafe_where: bool = False,
):
self._conn: Any = None
self._schema_name = schema or "public"
self._schema_cache: SchemaGraph | None = None
self.batch_size = batch_size or self.DEFAULT_BATCH_SIZE
self.profiler = profiler
self.allow_unsafe_where = allow_unsafe_where

def connect(self, url: str) -> None:
"""Establish PostgreSQL connection."""
Expand Down Expand Up @@ -216,7 +218,8 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]:
a_source.attname AS source_column,
target_cls.relname AS target_table,
a_target.attname AS target_column,
NOT a_source.attnotnull AS is_nullable
NOT a_source.attnotnull AS is_nullable,
c.condeferrable AS is_deferrable
FROM pg_constraint c
JOIN pg_class source_cls ON c.conrelid = source_cls.oid
JOIN pg_class target_cls ON c.confrelid = target_cls.oid
Expand All @@ -239,7 +242,15 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]:
# Group by constraint name for multi-column FKs
fk_data: dict[str, dict] = {}
for row in cur.fetchall():
constraint_name, source_table, source_col, target_table, target_col, is_nullable = (
(
constraint_name,
source_table,
source_col,
target_table,
target_col,
is_nullable,
is_deferrable,
) = (
row
)

Expand All @@ -251,6 +262,7 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]:
"target_table": target_table,
"target_columns": [],
"is_nullable": bool(is_nullable),
"is_deferrable": bool(is_deferrable),
}

fk_data[constraint_name]["source_columns"].append(source_col)
Expand All @@ -265,6 +277,7 @@ def _fetch_foreign_keys(self, schema: str) -> list[ForeignKey]:
target_table=data["target_table"],
target_columns=tuple(data["target_columns"]),
is_nullable=data["is_nullable"],
is_deferrable=data["is_deferrable"],
)
)

Expand All @@ -291,7 +304,11 @@ def fetch_rows(
# Defense-in-depth: validate WHERE clause even if it was validated earlier
from dbslice.config import validate_where_clause

validate_where_clause(where_clause, f"{table}:{where_clause}")
validate_where_clause(
where_clause,
f"{table}:{where_clause}",
allow_unsafe_subqueries=self.allow_unsafe_where,
)

try:
query = f'SELECT * FROM "{table}" WHERE {where_clause}'
Expand Down
130 changes: 125 additions & 5 deletions src/dbslice/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def _parse_env_comma_list(var_name: str) -> list[str] | None:
def _parse_and_validate_seeds(
seeds: list[str],
console: Console,
allow_unsafe_subqueries: bool = False,
) -> list[SeedSpec]:
"""
Parse and validate seed specifications from CLI arguments.
Expand All @@ -203,7 +204,9 @@ def _parse_and_validate_seeds(
parsed_seeds = []
for s in seeds:
try:
parsed_seeds.append(SeedSpec.parse(s))
parsed_seeds.append(
SeedSpec.parse(s, allow_unsafe_subqueries=allow_unsafe_subqueries)
)
except ValueError as e:
raise InvalidSeedError(s, str(e))

Expand Down Expand Up @@ -270,6 +273,7 @@ def _build_extract_config(
stream_chunk_size: int,
output_file_mode: int,
schema: str | None = None,
allow_unsafe_where: bool = False,
) -> ExtractConfig:
"""
Build ExtractConfig from validated CLI parameters.
Expand Down Expand Up @@ -321,6 +325,7 @@ def _build_extract_config(
streaming_chunk_size=stream_chunk_size,
output_file_mode=output_file_mode,
schema=schema,
allow_unsafe_where=bool(allow_unsafe_where),
)


Expand Down Expand Up @@ -402,9 +407,14 @@ def _show_extraction_summary(

if result.has_cycles:
console.print()
console.print("[yellow]⚠ Circular dependencies detected and resolved[/yellow]")
console.print(f" Broken FKs: [cyan]{len(result.broken_fks)}[/cyan]")
console.print(f" Deferred UPDATEs: [cyan]{len(result.deferred_updates)}[/cyan]")
if result.used_deferred_cycle_strategy:
console.print("[yellow]⚠ Circular dependencies detected (deferred-constraint strategy)[/yellow]")
console.print(" Strategy: [cyan]Deterministic order + SET CONSTRAINTS ALL DEFERRED[/cyan]")
console.print(f" Cycles: [cyan]{len(result.cycle_infos)}[/cyan]")
else:
console.print("[yellow]⚠ Circular dependencies detected and resolved[/yellow]")
console.print(f" Broken FKs: [cyan]{len(result.broken_fks)}[/cyan]")
console.print(f" Deferred UPDATEs: [cyan]{len(result.deferred_updates)}[/cyan]")
if config.verbose:
for cycle_info in result.cycle_infos:
console.print(f" [dim]Cycle: {cycle_info}[/dim]")
Expand Down Expand Up @@ -966,6 +976,16 @@ def extract(
help="PostgreSQL schema name (default: 'public')",
),
] = None,
allow_unsafe_where: Annotated[
bool | None,
typer.Option(
"--allow-unsafe-where/--no-allow-unsafe-where",
help=(
"Allow seed WHERE subqueries (e.g. IN (SELECT ... JOIN ...)). "
"Use only with trusted inputs."
),
),
] = None,
):
"""
Extract a database subset starting from seed record(s).
Expand Down Expand Up @@ -1034,6 +1054,9 @@ def extract(
redact_override = redact
if redact_override is None:
redact_override = _parse_env_comma_list("DBSLICE_REDACT_FIELDS")
allow_unsafe_where_override = allow_unsafe_where
if allow_unsafe_where_override is None:
allow_unsafe_where_override = _parse_env_bool("DBSLICE_ALLOW_UNSAFE_WHERE")
except ValueError as e:
console.print(f"[red]Validation Error:[/red] {e}")
raise typer.Exit(1)
Expand Down Expand Up @@ -1076,6 +1099,16 @@ def extract(
)
raise typer.Exit(1)

effective_allow_unsafe_where = (
allow_unsafe_where_override
if allow_unsafe_where_override is not None
else (
loaded_config.extraction.allow_unsafe_where
if loaded_config is not None
else False
)
)

resolved_database_url = database_url_override
if not resolved_database_url and loaded_config:
resolved_database_url = loaded_config.database.url
Expand Down Expand Up @@ -1138,7 +1171,11 @@ def extract(
console.print(f"[red]Validation Error:[/red] {e}")
raise typer.Exit(1)

seed_specs = _parse_and_validate_seeds(seed or [], console)
seed_specs = _parse_and_validate_seeds(
seed or [],
console,
allow_unsafe_subqueries=effective_allow_unsafe_where,
)

if loaded_config:
direction_enum = (
Expand Down Expand Up @@ -1176,6 +1213,7 @@ def extract(
if output_file_mode is not None
else None,
schema=schema,
allow_unsafe_where=allow_unsafe_where_override,
)
output_format = extract_config.output_format
else:
Expand Down Expand Up @@ -1204,6 +1242,7 @@ def extract(
stream_chunk_size=effective_stream_chunk_size,
output_file_mode=effective_output_file_mode,
schema=schema,
allow_unsafe_where=effective_allow_unsafe_where,
)

if verbose and not no_progress:
Expand Down Expand Up @@ -1498,6 +1537,59 @@ def _detect_sensitive_fields(schema) -> dict[str, str]:
return detected


def _detect_potential_implicit_fks(schema) -> list[tuple[str, str, str]]:
"""
Detect likely implicit FK relationships using naming heuristics.

Returns:
List of (source_table, source_column, target_table) tuples.
"""
existing_fk_columns: set[tuple[str, str]] = set()
for fk in schema.edges:
for source_col in fk.source_columns:
existing_fk_columns.add((fk.source_table, source_col))

tables_by_lower = {name.lower(): name for name in schema.tables.keys()}

candidates: list[tuple[str, str, str]] = []

for table_name, table in schema.tables.items():
for column in table.columns:
if (table_name, column.name) in existing_fk_columns:
continue

col_lower = column.name.lower()
if not col_lower.endswith("_id"):
continue

base = col_lower[:-3]
if not base:
continue

guesses = [base, f"{base}s"]
if base.endswith("y") and len(base) > 1:
guesses.append(f"{base[:-1]}ies")
else:
guesses.append(f"{base}es")

target_table: str | None = None
for guess in guesses:
actual = tables_by_lower.get(guess)
if not actual or actual == table_name:
continue
target_info = schema.tables.get(actual)
if not target_info or not target_info.primary_key:
continue
if "id" in {pk.lower() for pk in target_info.primary_key}:
target_table = actual
break

if target_table:
candidates.append((table_name, column.name, target_table))

return sorted(candidates, key=lambda item: (item[0], item[1], item[2]))


@app.command()
def inspect(
database_url: Annotated[
Expand Down Expand Up @@ -1594,6 +1686,21 @@ def inspect(
f" [cyan]{child_table}[/cyan].{', '.join(fk.source_columns)}"
)

implicit_candidates = [
candidate
for candidate in _detect_potential_implicit_fks(db_schema)
if candidate[0] == table
]
if implicit_candidates:
console.print("\n [yellow]Potential implicit relationships:[/yellow]")
for src_table, src_col, target_table in implicit_candidates:
console.print(
f" {src_table}.{src_col} -> [cyan]{target_table}[/cyan].id"
)
console.print(
" [dim]Tip: add these as virtual_foreign_keys if they are real relationships.[/dim]"
)

else:
console.print(f"\n[bold]Tables ({len(db_schema.tables)})[/bold]")
for name in sorted(db_schema.tables.keys()):
Expand All @@ -1616,6 +1723,19 @@ def inspect(
for fk in self_refs:
console.print(f" {fk.source_table}.{', '.join(fk.source_columns)}")

implicit_candidates = _detect_potential_implicit_fks(db_schema)
if implicit_candidates:
console.print("\n[yellow]Potential implicit relationships:[/yellow]")
for src_table, src_col, target_table in implicit_candidates[:25]:
console.print(f" {src_table}.{src_col} -> [cyan]{target_table}[/cyan].id")
if len(implicit_candidates) > 25:
console.print(
f" [dim]... and {len(implicit_candidates) - 25} more[/dim]"
)
console.print(
" [dim]Tip: define virtual_foreign_keys for confirmed implicit links.[/dim]"
)

finally:
adapter.close()

Expand Down
Loading