diff --git a/django-backend/soroscan/ingest/management/commands/import_contracts.py b/django-backend/soroscan/ingest/management/commands/import_contracts.py new file mode 100644 index 00000000..2c83622f --- /dev/null +++ b/django-backend/soroscan/ingest/management/commands/import_contracts.py @@ -0,0 +1,154 @@ +""" +Management command: import_contracts + +Imports TrackedContract records from a JSON file containing address/name mappings. + +Usage: + python manage.py import_contracts --input contracts.json + python manage.py import_contracts --file contracts.json --owner admin +""" +import json + +from django.contrib.auth import get_user_model +from django.core.exceptions import ValidationError +from django.core.management.base import BaseCommand, CommandError + +from soroscan.ingest.models import TrackedContract + + +DEFAULT_IMPORT_USERNAME = "soroscan-import" + + +class Command(BaseCommand): + help = "Import tracked contracts from a JSON address/name mapping file." + + def add_arguments(self, parser): + parser.add_argument( + "--file", + "--input", + dest="file", + required=True, + help="Input JSON file path", + ) + parser.add_argument( + "--owner", + default=None, + help=( + "Username, email, or id of the owner for newly imported contracts. " + f"Defaults to a service user named {DEFAULT_IMPORT_USERNAME}." + ), + ) + + def handle(self, *args, **options): + path = options["file"] + + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.JSONDecodeError as exc: + raise CommandError(f"Invalid JSON in {path}: {exc}") from exc + except OSError as exc: + raise CommandError(f"Unable to read {path}: {exc}") from exc + + owner = self._resolve_owner(options["owner"]) + contracts = self._extract_contracts(data) + + created = 0 + skipped = 0 + + for contract in contracts: + _, was_created = TrackedContract.objects.get_or_create( + contract_id=contract["contract_id"], + defaults={ + "name": contract["name"], + "owner": owner, + }, + ) + if was_created: + created += 1 + else: + skipped += 1 + + self.stdout.write( + self.style.SUCCESS( + f"Imported contracts: created={created} skipped_existing={skipped}" + ) + ) + + def _resolve_owner(self, owner_lookup): + User = get_user_model() + + if not owner_lookup: + owner, _ = User.objects.get_or_create( + username=DEFAULT_IMPORT_USERNAME, + defaults={"email": "soroscan-import@example.com"}, + ) + return owner + + filters = [{"username": owner_lookup}, {"email": owner_lookup}] + if str(owner_lookup).isdigit(): + filters.insert(0, {"id": int(owner_lookup)}) + + for lookup in filters: + try: + return User.objects.get(**lookup) + except User.DoesNotExist: + continue + except User.MultipleObjectsReturned as exc: + raise CommandError( + f"Multiple users matched owner {owner_lookup!r}" + ) from exc + + raise CommandError(f"Owner user not found: {owner_lookup}") + + def _extract_contracts(self, data): + if isinstance(data, dict) and "contracts" in data: + raw_contracts = data["contracts"] + elif isinstance(data, dict): + raw_contracts = self._mapping_to_contracts(data) + elif isinstance(data, list): + raw_contracts = data + else: + raise CommandError( + "Invalid import format: expected a JSON object or list of contracts." + ) + + if isinstance(raw_contracts, dict): + raw_contracts = self._mapping_to_contracts(raw_contracts) + + if not isinstance(raw_contracts, list): + raise CommandError("Invalid import format: contracts must be a list.") + + contracts = [] + for index, item in enumerate(raw_contracts, start=1): + try: + contract = self._normalize_contract(item) + except (TypeError, ValueError, ValidationError) as exc: + raise CommandError(f"Invalid contract at row {index}: {exc}") from exc + contracts.append(contract) + + return contracts + + @staticmethod + def _mapping_to_contracts(mapping): + return [ + {"contract_id": contract_id, "name": name} + for contract_id, name in mapping.items() + ] + + def _normalize_contract(self, item): + if not isinstance(item, dict): + raise TypeError("expected an object with contract_id and name") + + contract_id = str(item.get("contract_id") or item.get("address") or "").strip() + name = str(item.get("name") or "").strip() + + if not contract_id: + raise ValueError("missing contract_id") + if not name: + raise ValueError("missing name") + + validator = TrackedContract(contract_id=contract_id, name=name) + validator.full_clean(exclude=["owner"], validate_unique=False) + + return {"contract_id": contract_id, "name": name} diff --git a/django-backend/soroscan/ingest/tests/test_import_contracts_command.py b/django-backend/soroscan/ingest/tests/test_import_contracts_command.py new file mode 100644 index 00000000..be255eef --- /dev/null +++ b/django-backend/soroscan/ingest/tests/test_import_contracts_command.py @@ -0,0 +1,119 @@ +import json + +import pytest +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError + +from soroscan.ingest.management.commands.import_contracts import ( + DEFAULT_IMPORT_USERNAME, +) +from soroscan.ingest.models import TrackedContract +from soroscan.ingest.tests.factories import TrackedContractFactory, UserFactory + + +def contract_id(n: int) -> str: + alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + encoded = "".join(alphabet[n >> (5 * i) & 0x1F] for i in range(54, -1, -1)) + return f"C{encoded}" + + +@pytest.mark.django_db +class TestImportContractsCommand: + def test_imports_contracts_from_export_format(self, tmp_path): + path = tmp_path / "contracts.json" + path.write_text( + json.dumps( + { + "contracts": [ + {"contract_id": contract_id(1), "name": "Token"}, + {"contract_id": contract_id(2), "name": "Marketplace"}, + ] + } + ), + encoding="utf-8", + ) + + call_command("import_contracts", "--input", str(path)) + + assert TrackedContract.objects.count() == 2 + assert TrackedContract.objects.get(contract_id=contract_id(1)).name == "Token" + assert ( + TrackedContract.objects.get(contract_id=contract_id(2)).name + == "Marketplace" + ) + assert get_user_model().objects.filter(username=DEFAULT_IMPORT_USERNAME).exists() + + def test_skips_existing_contracts_without_updating(self, tmp_path): + existing = TrackedContractFactory(contract_id=contract_id(3), name="Original") + path = tmp_path / "contracts.json" + path.write_text( + json.dumps( + { + "contracts": [ + {"contract_id": existing.contract_id, "name": "Changed"}, + {"contract_id": contract_id(4), "name": "New Contract"}, + ] + } + ), + encoding="utf-8", + ) + + call_command("import_contracts", file=str(path)) + + existing.refresh_from_db() + assert existing.name == "Original" + assert TrackedContract.objects.count() == 2 + assert TrackedContract.objects.filter(contract_id=contract_id(4)).exists() + + def test_uses_explicit_owner_for_new_contracts(self, tmp_path): + owner = UserFactory(username="admin") + path = tmp_path / "contracts.json" + path.write_text( + json.dumps( + {"contracts": [{"contract_id": contract_id(5), "name": "Owned"}]} + ), + encoding="utf-8", + ) + + call_command("import_contracts", file=str(path), owner=owner.username) + + assert TrackedContract.objects.get(contract_id=contract_id(5)).owner == owner + + def test_imports_plain_address_name_mapping(self, tmp_path): + path = tmp_path / "contracts.json" + path.write_text( + json.dumps({contract_id(6): "Plain Mapping"}), + encoding="utf-8", + ) + + call_command("import_contracts", file=str(path)) + + assert ( + TrackedContract.objects.get(contract_id=contract_id(6)).name + == "Plain Mapping" + ) + + def test_imports_nested_address_name_mapping(self, tmp_path): + path = tmp_path / "contracts.json" + path.write_text( + json.dumps({"contracts": {contract_id(7): "Nested Mapping"}}), + encoding="utf-8", + ) + + call_command("import_contracts", file=str(path)) + + assert ( + TrackedContract.objects.get(contract_id=contract_id(7)).name + == "Nested Mapping" + ) + + def test_invalid_contract_raises_command_error(self, tmp_path): + path = tmp_path / "contracts.json" + path.write_text( + json.dumps({"contracts": [{"contract_id": "not-valid", "name": "Bad"}]}), + encoding="utf-8", + ) + + with pytest.raises(CommandError, match="Invalid contract at row 1"): + call_command("import_contracts", file=str(path))