Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ services:
datacenter:
ipv4_address: 10.5.0.5
healthcheck:
test: ["CMD-SHELL", "pg_isready"]
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 10s
timeout: 5s
retries: 5
Expand Down
308 changes: 298 additions & 10 deletions pgbelt/cmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,30 @@
from pgbelt.config import get_config_async
from pgbelt.models.base import CommandError
from pgbelt.models.base import CommandResult
from pgbelt.models.connectivity import ConnectivityCheckResult
from pgbelt.models.connectivity import ConnectivityCheckRow
from pgbelt.models.connections import ConnectionsResult
from pgbelt.models.connections import ConnectionsRow
from pgbelt.models.connections import ConnectionsSide
from pgbelt.models.preflight import ExtensionInfo
from pgbelt.models.preflight import PrecheckResult
from pgbelt.models.preflight import PrecheckSide
from pgbelt.models.preflight import RelationInfo
from pgbelt.models.preflight import RoleInfo
from pgbelt.models.preflight import TableReplicationInfo
from pgbelt.models.schema import CreateIndexesResult
from pgbelt.models.schema import DiffSchemaRow
from pgbelt.models.schema import DiffSchemasResult
from pgbelt.models.schema import IndexDetail
from pgbelt.models.status import ReplicationLag
from pgbelt.models.status import StatusResult
from pgbelt.models.status import StatusRow
from pgbelt.models.sync import SequenceSyncDetail
from pgbelt.models.sync import SyncSequencesResult
from pgbelt.models.sync import SyncTablesResult
from pgbelt.models.sync import TableSyncDetail
from pgbelt.models.sync import TableValidationDetail
from pgbelt.models.sync import ValidateDataResult
from typer import Argument
from typer import Option
from typer import Typer
Expand All @@ -24,6 +48,262 @@
T = TypeVar("T")


def _build_connectivity_result(
    results: list[dict], base_kwargs: dict
) -> ConnectivityCheckResult:
    """Assemble a ConnectivityCheckResult from raw per-database dicts.

    Non-dict entries are skipped. Overall success is true only when
    every row reports all_ok (vacuously true for an empty row set).
    """
    rows: list[ConnectivityCheckRow] = []
    for entry in results:
        if isinstance(entry, dict):
            rows.append(ConnectivityCheckRow(**entry))
    overall_ok = all(row.all_ok for row in rows)
    return ConnectivityCheckResult(success=overall_ok, results=rows, **base_kwargs)


def _build_connections_result(
    results: list[dict], base_kwargs: dict
) -> ConnectionsResult:
    """Assemble a ConnectionsResult from raw per-database dicts.

    Each raw dict carries source/destination connection counts
    (``src_count``/``dst_count``) and optional per-user breakdowns
    (``src_usernames``/``dst_usernames``). Non-dict entries are skipped.

    Fix: missing keys now fall back to ""/0/{} via ``.get()`` instead of
    raising KeyError on a partial row, matching the defensive style of
    the other result builders in this module.
    """
    rows = []
    for r in results:
        if not isinstance(r, dict):
            continue
        rows.append(
            ConnectionsRow(
                db=r.get("db", ""),
                source=ConnectionsSide(
                    total_connections=r.get("src_count", 0),
                    by_user=r.get("src_usernames", {}),
                ),
                destination=ConnectionsSide(
                    total_connections=r.get("dst_count", 0),
                    by_user=r.get("dst_usernames", {}),
                ),
            )
        )
    return ConnectionsResult(success=True, results=rows, **base_kwargs)


def _build_status_result(results: list[dict], base_kwargs: dict) -> StatusResult:
    """Assemble a StatusResult from raw per-database status dicts.

    Non-dict entries are skipped. Unknown lag fields default to
    "unknown" and unconfigured replication directions to "unconfigured".
    """
    rows = []
    for entry in results:
        if not isinstance(entry, dict):
            continue
        lag_info = ReplicationLag(
            sent_lag=entry.get("sent_lag", "unknown"),
            write_lag=entry.get("write_lag", "unknown"),
            flush_lag=entry.get("flush_lag", "unknown"),
            replay_lag=entry.get("replay_lag", "unknown"),
        )
        row = StatusRow(
            db=entry.get("db", ""),
            forward_replication=entry.get("pg1_pg2", "unconfigured"),
            back_replication=entry.get("pg2_pg1", "unconfigured"),
            lag=lag_info,
            src_dataset_size=entry.get("src_dataset_size"),
            dst_dataset_size=entry.get("dst_dataset_size"),
            progress=entry.get("progress"),
        )
        rows.append(row)
    return StatusResult(success=True, results=rows, **base_kwargs)


def _build_precheck_side(raw: dict, pkeys: list | None = None) -> PrecheckSide:
    """Convert the raw dict from precheck_info into a PrecheckSide model.

    Args:
        raw: Raw precheck payload for one side (src or dst) with keys such
            as "users", "tables", "sequences", "extensions" and server
            configuration values. Missing keys fall back to safe defaults.
        pkeys: Optional list of table names that have a primary key; only
            supplied for the source side (see _build_precheck_result).

    Returns:
        A populated PrecheckSide model.
    """
    users_raw = raw.get("users", {})
    root_raw = users_raw.get("root", {})
    owner_raw = users_raw.get("owner", {})

    tables_raw = raw.get("tables", [])
    # Set for O(1) membership tests; empty when pkeys is None/empty.
    pkey_names = set(pkeys) if pkeys else set()

    tables = []
    for t in tables_raw:
        name = t.get("Name", "")
        schema = t.get("Schema", "")
        owner = t.get("Owner", "")
        has_pk = name in pkey_names
        # A table is replicable only when it lives in the target schema
        # and is owned by the configured owner role.
        can_rep = schema == raw.get("schema", "") and owner == owner_raw.get(
            "rolname", ""
        )
        if can_rep:
            # pglogical needs a primary key; otherwise fall back to
            # dump-and-load for that table.
            method = "pglogical" if has_pk else "dump_and_load"
        else:
            method = "unavailable"
        tables.append(
            TableReplicationInfo(
                name=name,
                schema_name=schema,
                owner=owner,
                has_primary_key=has_pk,
                replication_method=method,
            )
        )

    sequences = [
        RelationInfo(
            name=s.get("Name", ""),
            schema_name=s.get("Schema", ""),
            owner=s.get("Owner", ""),
            object_type="sequence",
        )
        for s in raw.get("sequences", [])
    ]

    # Extensions may arrive as dicts ({"extname": ...}) or bare strings.
    extensions = [
        ExtensionInfo(extname=e["extname"] if isinstance(e, dict) else e)
        for e in raw.get("extensions", [])
    ]

    return PrecheckSide(
        db=raw.get("db", ""),
        schema_name=raw.get("schema", "public"),
        server_version=raw.get("server_version", "unknown"),
        max_replication_slots=raw.get("max_replication_slots", "0"),
        max_worker_processes=raw.get("max_worker_processes", "0"),
        max_wal_senders=raw.get("max_wal_senders", "0"),
        shared_preload_libraries=raw.get("shared_preload_libraries", []),
        rds_logical_replication=raw.get("rds.logical_replication", "unknown"),
        root_user=RoleInfo(
            rolname=root_raw.get("rolname", ""),
            rolcanlogin=root_raw.get("rolcanlogin", False),
            rolcreaterole=root_raw.get("rolcreaterole", False),
            rolinherit=root_raw.get("rolinherit", False),
            rolsuper=root_raw.get("rolsuper", False),
            memberof=root_raw.get("memberof", []),
        ),
        owner_user=RoleInfo(
            rolname=owner_raw.get("rolname", ""),
            rolcanlogin=owner_raw.get("rolcanlogin", False),
            rolcreaterole=owner_raw.get("rolcreaterole", False),
            rolinherit=owner_raw.get("rolinherit", False),
            rolsuper=owner_raw.get("rolsuper", False),
            memberof=owner_raw.get("memberof", []),
            # can_create has no default here: None means "not reported".
            can_create=owner_raw.get("can_create"),
        ),
        tables=tables,
        sequences=sequences,
        extensions=extensions,
    )


def _build_precheck_result(results: list[dict], base_kwargs: dict) -> PrecheckResult:
    """Assemble a PrecheckResult.

    Single-DB runs produce one dict with "src"/"dst" payloads which are
    converted into PrecheckSide models; extensions on each side are
    cross-flagged when present on the other side. Multi-DB runs return
    a result without sides.
    """
    if len(results) != 1 or not isinstance(results[0], dict):
        # Multi-DB: store raw dicts since we get separate src/dst per DB
        return PrecheckResult(success=True, **base_kwargs)

    raw = results[0]
    src_raw = raw.get("src", {})
    src_side = _build_precheck_side(src_raw, pkeys=src_raw.get("pkeys"))
    dst_side = _build_precheck_side(raw.get("dst", {}))

    # Mark which extensions exist on the opposite side.
    names_on_src = {ext.extname for ext in src_side.extensions}
    names_on_dst = {ext.extname for ext in dst_side.extensions}
    for ext in dst_side.extensions:
        ext.in_other_side = ext.extname in names_on_src
    for ext in src_side.extensions:
        ext.in_other_side = ext.extname in names_on_dst

    return PrecheckResult(success=True, src=src_side, dst=dst_side, **base_kwargs)


def _build_sync_sequences_result(
    results: list[dict], base_kwargs: dict
) -> SyncSequencesResult:
    """Assemble a SyncSequencesResult; rich detail only for single-DB runs."""
    if len(results) != 1 or not isinstance(results[0], dict):
        return SyncSequencesResult(success=True, **base_kwargs)

    raw = results[0]
    pk_details = [SequenceSyncDetail(**item) for item in raw.get("pk_sequences", [])]
    non_pk_details = [
        SequenceSyncDetail(**item) for item in raw.get("non_pk_sequences", [])
    ]
    return SyncSequencesResult(
        success=True,
        schema_name=raw.get("schema_name"),
        stride=raw.get("stride"),
        pk_sequences=pk_details,
        non_pk_sequences=non_pk_details,
        **base_kwargs,
    )


def _build_sync_tables_result(
    results: list[dict], base_kwargs: dict
) -> SyncTablesResult:
    """Assemble a SyncTablesResult; rich detail only for single-DB runs."""
    if len(results) != 1 or not isinstance(results[0], dict):
        return SyncTablesResult(success=True, **base_kwargs)

    raw = results[0]
    details = [TableSyncDetail(**item) for item in raw.get("tables", [])]
    return SyncTablesResult(
        success=True,
        schema_name=raw.get("schema_name"),
        discovery_mode=raw.get("discovery_mode", "auto"),
        tables=details,
        **base_kwargs,
    )


def _build_validate_data_result(
    results: list[dict], base_kwargs: dict
) -> ValidateDataResult:
    """Assemble a ValidateDataResult; success requires every table to pass."""
    if len(results) != 1 or not isinstance(results[0], dict):
        return ValidateDataResult(success=True, **base_kwargs)

    raw = results[0]
    table_rows = raw.get("tables", [])
    # A table with no "passed" key is treated as having passed.
    all_passed = all(t.get("passed", True) for t in table_rows)
    return ValidateDataResult(
        success=all_passed,
        schema_name=raw.get("schema_name"),
        tables=[TableValidationDetail(**t) for t in table_rows],
        **base_kwargs,
    )


def _build_create_indexes_result(
    results: list[dict], base_kwargs: dict
) -> CreateIndexesResult:
    """Assemble a CreateIndexesResult; success means no index failed."""
    if len(results) != 1 or not isinstance(results[0], dict):
        return CreateIndexesResult(success=True, **base_kwargs)

    raw = results[0]
    details = [IndexDetail(**entry) for entry in raw.get("indexes", [])]
    all_ok = all(detail.status != "failed" for detail in details)
    return CreateIndexesResult(
        success=all_ok,
        indexes_file=raw.get("indexes_file"),
        indexes=details,
        analyze_ran=raw.get("analyze_ran", False),
        **base_kwargs,
    )


def _build_diff_schemas_result(
    results: list[dict], base_kwargs: dict
) -> DiffSchemasResult:
    """Assemble a DiffSchemasResult; success means no schema mismatch."""
    rows = []
    for entry in results:
        if not isinstance(entry, dict):
            continue
        rows.append(
            DiffSchemaRow(
                db=entry.get("db", ""),
                result=entry.get("result", "skipped"),
                diff=entry.get("diff"),
            )
        )
    all_match = all(row.result != "mismatch" for row in rows)
    return DiffSchemasResult(success=all_match, results=rows, **base_kwargs)


# Maps a CLI command name to the builder that converts its raw
# per-database result dicts into a rich typed result model. Commands
# absent from this table fall back to the generic CommandResult path.
_RICH_MODEL_BUILDERS: dict[str, Callable] = {
    "check-connectivity": _build_connectivity_result,
    "connections": _build_connections_result,
    "status": _build_status_result,
    "precheck": _build_precheck_result,
    "sync-sequences": _build_sync_sequences_result,
    "sync-tables": _build_sync_tables_result,
    "validate-data": _build_validate_data_result,
    "create-indexes": _build_create_indexes_result,
    "diff-schemas": _build_diff_schemas_result,
}


def _build_json_output(
command_name: str,
dc: str,
Expand All @@ -41,21 +321,29 @@ def _build_json_output(
message=str(error),
)

detail = {}
if results and len(results) == 1 and isinstance(results[0], dict):
detail = results[0]
elif results and all(isinstance(r, dict) for r in results):
detail = {"databases": results}

result = CommandResult(
base_kwargs = dict(
db=db or dc,
dc=dc,
command=command_name,
success=success,
duration_ms=duration_ms,
error=cmd_error,
detail=detail,
)

builder = _RICH_MODEL_BUILDERS.get(command_name)
if builder and not cmd_error:
result = builder(results, base_kwargs)
else:
detail = {}
if results and len(results) == 1 and isinstance(results[0], dict):
detail = results[0]
elif results and all(isinstance(r, dict) for r in results):
detail = {"databases": results}
result = CommandResult(
command=command_name,
success=success,
detail=detail,
**base_kwargs,
)

return result.model_dump_json(indent=2)


Expand Down
21 changes: 17 additions & 4 deletions pgbelt/cmd/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from typer import Option
from typer import style

from typing import Any

from pgbelt.cmd.helpers import run_with_configs
from pgbelt.config.models import DbupgradeConfig
from pgbelt.util.dump import apply_target_constraints
from pgbelt.util.dump import apply_target_schema
from pgbelt.util.dump import create_target_indexes
from pgbelt.util.dump import create_target_indexes_with_details
from pgbelt.util.dump import dump_source_schema
from pgbelt.util.dump import remove_dst_not_valid_constraints
from pgbelt.util.dump import remove_dst_indexes
from pgbelt.util.dump import schema_file
from pgbelt.util.dump import ONLY_INDEXES
from pgbelt.util.dump import validate_schema_dump
from pgbelt.util.logs import get_logger
from pgbelt.util.postgres import run_analyze
Expand Down Expand Up @@ -86,7 +90,9 @@ async def remove_indexes(config_future: Awaitable[DbupgradeConfig]) -> None:


@run_with_configs(skip_src=True)
async def create_indexes(config_future: Awaitable[DbupgradeConfig]) -> None:
async def create_indexes(
config_future: Awaitable[DbupgradeConfig],
) -> dict[str, Any] | None:
"""
Creates indexes from the file schemas/dc/db/indexes.sql into the destination
as the owner user. This must only be done after most data is synchronized
Expand All @@ -98,9 +104,10 @@ async def create_indexes(config_future: Awaitable[DbupgradeConfig]) -> None:
"""
conf = await config_future
logger = get_logger(conf.db, conf.dc, "schema.dst")
await create_target_indexes(conf, logger, during_sync=False)
index_details = await create_target_indexes_with_details(
conf, logger, during_sync=False
)

# Run ANALYZE after creating indexes (without statement timeout)
async with create_pool(
conf.dst.root_uri,
min_size=1,
Expand All @@ -110,6 +117,12 @@ async def create_indexes(config_future: Awaitable[DbupgradeConfig]) -> None:
) as dst_pool:
await run_analyze(dst_pool, logger)

return {
"indexes_file": schema_file(conf.db, conf.dc, ONLY_INDEXES),
"indexes": index_details,
"analyze_ran": True,
}


async def _print_diff_table(results: list[dict[str, str]]) -> list[list[str]]:
table = [
Expand Down
Loading
Loading