diff --git a/django-backend/soroscan/ingest/serializers.py b/django-backend/soroscan/ingest/serializers.py index 165d6469..4c8ca9b0 100644 --- a/django-backend/soroscan/ingest/serializers.py +++ b/django-backend/soroscan/ingest/serializers.py @@ -1,6 +1,8 @@ """ DRF Serializers for SoroScan API. """ +import re + from rest_framework import serializers from django.utils.text import slugify @@ -22,6 +24,9 @@ WebhookSubscription, ) +_CONTRACT_ID_RE = re.compile(r"^C[A-Z2-7]{55}$") +_VALID_NETWORKS = {choice[0] for choice in TrackedContract.Network.choices} + class OrganizationSerializer(serializers.ModelSerializer): """Organization serializer with owner-managed tenancy settings.""" @@ -127,6 +132,11 @@ class TrackedContractSerializer(serializers.ModelSerializer): Used for creating, updating, and returning tracked Soroban smart contracts. """ + # Declare these as plain CharField to bypass model-level RegexValidator/UniqueValidator + # and choices validation so our validate_* methods control error messages entirely. + contract_id = serializers.CharField(validators=[]) + network = serializers.CharField(required=False, default=TrackedContract.Network.MAINNET) + event_count = serializers.SerializerMethodField() warnings = serializers.SerializerMethodField() team = serializers.PrimaryKeyRelatedField( @@ -143,6 +153,7 @@ class Meta: "name", "alias", "description", + "network", "abi_schema", "json_schema", "is_active", @@ -169,6 +180,34 @@ def get_warnings(self, obj) -> list[dict[str, str]]: warning = obj.deprecation_warning() return [warning] if warning else [] + def validate_contract_id(self, value: str) -> str: + value = value.strip() + + if not _CONTRACT_ID_RE.match(value): + raise serializers.ValidationError( + "Invalid contract address. A Soroban contract address must start " + "with 'C', be exactly 56 characters long, and use only uppercase " + "Base32 characters (A-Z and 2-7)." + ) + + # On create, reject duplicates with a clear message. + if self.instance is None: + if TrackedContract.objects.filter(contract_id=value).exists(): + raise serializers.ValidationError( + f"Contract '{value}' is already registered. " + "Each contract address can only be tracked once." + ) + + return value + + def validate_network(self, value: str) -> str: + if value not in _VALID_NETWORKS: + valid = ", ".join(sorted(_VALID_NETWORKS)) + raise serializers.ValidationError( + f"'{value}' is not a valid network. Choose one of: {valid}." + ) + return value + def validate_team(self, value): request = self.context.get("request") user = getattr(request, "user", None) diff --git a/django-backend/soroscan/ingest/tests/test_contract_validation.py b/django-backend/soroscan/ingest/tests/test_contract_validation.py index 70cafcac..a67aef0e 100644 --- a/django-backend/soroscan/ingest/tests/test_contract_validation.py +++ b/django-backend/soroscan/ingest/tests/test_contract_validation.py @@ -1,11 +1,12 @@ """ -Tests for Contract model validation rules (issue #). +Tests for Contract model and serializer validation rules (issue #590). """ import pytest from django.core.exceptions import ValidationError from soroscan.ingest.models import TrackedContract -from soroscan.ingest.tests.factories import UserFactory +from soroscan.ingest.serializers import TrackedContractSerializer +from soroscan.ingest.tests.factories import TrackedContractFactory, UserFactory @pytest.mark.django_db @@ -93,7 +94,7 @@ def test_empty_and_whitespace(self): "CABCDEFGHIJKLMNOPQRSTUVWXYZ234567ABCDEFGHIJKLMNOPQRST ", # trailing space " CABCDEFGHIJKLMNOPQRSTUVWXYZ234567ABCDEFGHIJKLMNOPQRST", # leading space ] - + for invalid_address in invalid_addresses: contract = TrackedContract( contract_id=invalid_address, @@ -103,3 +104,110 @@ def test_empty_and_whitespace(self): with pytest.raises(ValidationError) as exc: contract.full_clean() assert "contract_id" in exc.value.error_dict + + +# ── Serializer-level validation tests ───────────────────────────────────────── + +_VALID_CONTRACT_ID = "C" + "A" * 55 +_VALID_PAYLOAD = { + "contract_id": _VALID_CONTRACT_ID, + "name": "My Contract", + "network": "testnet", +} + + +def _serialize(data, instance=None): + return TrackedContractSerializer(instance=instance, data=data) + + +@pytest.mark.django_db +class TestTrackedContractSerializerValidation: + + # --- contract_id format --- + + def test_valid_contract_id_passes(self): + s = _serialize(_VALID_PAYLOAD) + assert s.is_valid(), s.errors + + def test_invalid_prefix_rejected_by_serializer(self): + data = {**_VALID_PAYLOAD, "contract_id": "G" + "A" * 55} + s = _serialize(data) + assert not s.is_valid() + assert "contract_id" in s.errors + assert "Soroban contract address" in str(s.errors["contract_id"]) + + def test_wrong_length_rejected_by_serializer(self): + for bad_id in ["C" + "A" * 54, "C" + "A" * 56]: + s = _serialize({**_VALID_PAYLOAD, "contract_id": bad_id}) + assert not s.is_valid() + assert "contract_id" in s.errors + + def test_invalid_charset_rejected_by_serializer(self): + bad_id = "C" + "0" * 55 # '0' is not in Base32 alphabet + s = _serialize({**_VALID_PAYLOAD, "contract_id": bad_id}) + assert not s.is_valid() + assert "contract_id" in s.errors + + def test_leading_whitespace_stripped_and_validated(self): + # Padded with spaces makes length wrong → should fail + bad_id = " " + "C" + "A" * 54 # 56 chars but leading space stripped → 55 chars + s = _serialize({**_VALID_PAYLOAD, "contract_id": bad_id}) + assert not s.is_valid() + assert "contract_id" in s.errors + + # --- duplicate check --- + + def test_duplicate_contract_id_rejected(self): + TrackedContractFactory(contract_id=_VALID_CONTRACT_ID) + s = _serialize(_VALID_PAYLOAD) + assert not s.is_valid() + assert "contract_id" in s.errors + assert "already registered" in str(s.errors["contract_id"]) + + def test_duplicate_check_skipped_on_update(self): + """Re-submitting the same contract_id on an update (PUT/PATCH) must not fail.""" + existing = TrackedContractFactory(contract_id=_VALID_CONTRACT_ID) + s = _serialize({**_VALID_PAYLOAD, "name": "Updated Name"}, instance=existing) + assert s.is_valid(), s.errors + + # --- network validity --- + + def test_valid_networks_accepted(self): + for net in ("mainnet", "testnet", "futurenet"): + s = _serialize({**_VALID_PAYLOAD, "contract_id": "C" + "B" * 55, "network": net}) + assert s.is_valid(), f"Expected {net} to be valid, got: {s.errors}" + + def test_invalid_network_rejected(self): + s = _serialize({**_VALID_PAYLOAD, "network": "devnet"}) + assert not s.is_valid() + assert "network" in s.errors + assert "valid network" in str(s.errors["network"]).lower() + + def test_empty_network_rejected(self): + s = _serialize({**_VALID_PAYLOAD, "network": ""}) + assert not s.is_valid() + assert "network" in s.errors + + # --- error message clarity --- + + def test_error_message_mentions_base32(self): + bad = {**_VALID_PAYLOAD, "contract_id": "XABCDE" + "A" * 50} + s = _serialize(bad) + assert not s.is_valid() + msg = str(s.errors["contract_id"]) + assert "Base32" in msg or "C" in msg + + def test_network_error_lists_valid_choices(self): + s = _serialize({**_VALID_PAYLOAD, "network": "unknown"}) + assert not s.is_valid() + msg = str(s.errors["network"]) + # At least one valid network name should appear in the error + assert any(n in msg for n in ("mainnet", "testnet", "futurenet")) + + # --- network field is now included in output --- + + def test_network_field_present_in_serialized_output(self): + contract = TrackedContractFactory(network="testnet") + s = TrackedContractSerializer(instance=contract) + assert "network" in s.data + assert s.data["network"] == "testnet"