From bbbeae503df6c48868a76ce67b61e656ca6987ba Mon Sep 17 00:00:00 2001 From: faisal-link Date: Tue, 7 Apr 2026 23:16:35 +0400 Subject: [PATCH 1/4] misc test and dummy contract updates --- .../ccip/sources/offramp_state_helper.move | 2 +- .../ccip/ccip/sources/receiver_registry.move | 24 +++++++++++---- .../tests/offramp_state_helper_tests.move | 1 + .../ccip/tests/receiver_registry_tests.move | 29 +++++++++++-------- .../sources/ccip_dummy_receiver.move | 3 +- 5 files changed, 40 insertions(+), 19 deletions(-) diff --git a/contracts/ccip/ccip/sources/offramp_state_helper.move b/contracts/ccip/ccip/sources/offramp_state_helper.move index dbd2ac86e..d4a1acfdd 100644 --- a/contracts/ccip/ccip/sources/offramp_state_helper.move +++ b/contracts/ccip/ccip/sources/offramp_state_helper.move @@ -233,7 +233,7 @@ public fun consume_any2sui_message( let receiver_package_id = address::from_ascii_bytes(&ascii::into_bytes(address_str)); let receiver_config = receiver_registry::get_receiver_config(ref, receiver_package_id); - let (_, proof_typename) = receiver_registry::get_receiver_config_fields(receiver_config); + let (_, proof_typename, _) = receiver_registry::get_receiver_config_fields(receiver_config); assert!(proof_typename == proof_tn.into_string(), ETypeProofMismatch); client::consume_any2sui_message(message, receiver_package_id) diff --git a/contracts/ccip/ccip/sources/receiver_registry.move b/contracts/ccip/ccip/sources/receiver_registry.move index dc0337906..dd7532961 100644 --- a/contracts/ccip/ccip/sources/receiver_registry.move +++ b/contracts/ccip/ccip/sources/receiver_registry.move @@ -13,6 +13,12 @@ use sui::linked_table::{Self, LinkedTable}; public struct ReceiverConfig has copy, drop, store { module_name: String, proof_typename: ascii::String, + /// The number of extra object IDs that the receiver's ccip_receive callback + /// expects beyond the standard 3 parameters (expected_message_id, + /// &CCIPObjectRef, Any2SuiMessage). The relayer uses this to validate that + /// the receiverObjectIds count in a CCIP message matches what the receiver + /// registered, preventing object injection attacks. + expected_receiver_object_id_count: u64, } public struct ReceiverRegistry has key, store { @@ -25,6 +31,7 @@ public struct ReceiverRegistered has copy, drop { receiver_package_id: address, receiver_module_name: String, proof_typename: ascii::String, + expected_receiver_object_id_count: u64, } public struct ReceiverUnregistered has copy, drop { @@ -57,6 +64,7 @@ public fun register_receiver( ref: &mut CCIPObjectRef, publisher_wrapper: PublisherWrapper, _proof: ProofType, + expected_receiver_object_id_count: u64, ) { verify_function_allowed( ref, @@ -73,6 +81,7 @@ public fun register_receiver( let receiver_config = ReceiverConfig { module_name: receiver_module_name, proof_typename: proof_typename.into_string(), + expected_receiver_object_id_count, }; registry.receiver_configs.push_back(receiver_package_id, receiver_config); @@ -80,6 +89,7 @@ public fun register_receiver( receiver_package_id, receiver_module_name, proof_typename: proof_typename.into_string(), + expected_receiver_object_id_count, }); } @@ -132,15 +142,15 @@ public fun get_receiver_config(ref: &CCIPObjectRef, receiver_package_id: address *registry.receiver_configs.borrow(receiver_package_id) } -public fun get_receiver_config_fields(rc: ReceiverConfig): (String, ascii::String) { - (rc.module_name, rc.proof_typename) +public fun get_receiver_config_fields(rc: ReceiverConfig): (String, ascii::String, u64) { + (rc.module_name, rc.proof_typename, rc.expected_receiver_object_id_count) } // this will return empty string if the receiver is not registered. public fun get_receiver_info( ref: &CCIPObjectRef, receiver_package_id: address, -): (String, ascii::String) { +): (String, ascii::String, u64) { verify_function_allowed( ref, string::utf8(b"receiver_registry"), @@ -151,8 +161,12 @@ public fun get_receiver_info( if (registry.receiver_configs.contains(receiver_package_id)) { let receiver_config = registry.receiver_configs.borrow(receiver_package_id); - return (receiver_config.module_name, receiver_config.proof_typename) + return ( + receiver_config.module_name, + receiver_config.proof_typename, + receiver_config.expected_receiver_object_id_count, + ) }; - (string::utf8(b""), ascii::string(b"")) + (string::utf8(b""), ascii::string(b""), 0) } diff --git a/contracts/ccip/ccip/tests/offramp_state_helper_tests.move b/contracts/ccip/ccip/tests/offramp_state_helper_tests.move index 0f69d4ef5..5cb8f4980 100644 --- a/contracts/ccip/ccip/tests/offramp_state_helper_tests.move +++ b/contracts/ccip/ccip/tests/offramp_state_helper_tests.move @@ -258,6 +258,7 @@ public fun test_extract_any2sui_message() { &mut ref, publisher_wrapper, TestTypeProof {}, + 0, ); package::burn_publisher(publisher); diff --git a/contracts/ccip/ccip/tests/receiver_registry_tests.move b/contracts/ccip/ccip/tests/receiver_registry_tests.move index 43d668de6..54e9d50ad 100644 --- a/contracts/ccip/ccip/tests/receiver_registry_tests.move +++ b/contracts/ccip/ccip/tests/receiver_registry_tests.move @@ -67,7 +67,7 @@ fun register_test_receiver( let publisher = package::test_claim(RECEIVER_REGISTRY_TESTS {}, ctx); let publisher_wrapper = publisher_wrapper::create(&publisher, proof); - receiver_registry::register_receiver(ref, publisher_wrapper, proof); + receiver_registry::register_receiver(ref, publisher_wrapper, proof, 0); package::burn_publisher(publisher); } @@ -114,12 +114,13 @@ public fun test_register_receiver() { // Get receiver config and verify fields let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); + assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -178,7 +179,7 @@ public fun test_register_multiple_receivers_same_package() { // Verify the config contains the first proof type let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (_, proof_type) = receiver_registry::get_receiver_config_fields(config); + let (_, proof_type, _) = receiver_registry::get_receiver_config_fields(config); assert!( proof_type == type_name::into_string(type_name::with_defining_ids()), @@ -278,13 +279,14 @@ public fun test_get_receiver_config() { // Get the config let package_id_1 = get_package_id_from_proof(); let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); // Verify all fields assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); + assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -298,26 +300,27 @@ public fun test_get_receiver_module_and_state() { // Test unregistered receiver - should return empty values let package_id_1 = get_package_id_from_proof(); - let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str, expected_count) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(module_name == string::utf8(b"")); assert!(proof_typename_str == ascii::string(b"")); + assert!(expected_count == 0); // Register a receiver register_test_receiver(&mut ref, TestReceiverProof {}, ctx); // Test registered receiver - should return actual values - let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str, expected_count) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(module_name == string::utf8(b"receiver_registry_tests")); - // The proof typename string should contain the test receiver proof type assert!( proof_typename_str == type_name::into_string(type_name::with_defining_ids()), ); + assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -335,10 +338,10 @@ public fun test_register_receiver_with_zero_state_id() { // Verify the receiver is registered let package_id_1 = get_package_id_from_proof(); let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (_, _) = receiver_registry::get_receiver_config_fields(config); + let (_, _, _) = receiver_registry::get_receiver_config_fields(config); // Verify get_receiver_info returns correct values - let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str, _) = receiver_registry::get_receiver_info( &ref, package_id_1, ); @@ -367,15 +370,16 @@ public fun test_complete_receiver_lifecycle() { // 3. Verify config is correct let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); + assert!(expected_count == 0); // 4. Verify module and proof typename lookup - let (lookup_module, lookup_proof_typename_str) = receiver_registry::get_receiver_info( + let (lookup_module, lookup_proof_typename_str, _) = receiver_registry::get_receiver_info( &ref, package_id_1, ); @@ -387,12 +391,13 @@ public fun test_complete_receiver_lifecycle() { assert!(!receiver_registry::is_registered_receiver(&ref, package_id_1)); // 6. Verify lookup returns empty values after unregistration - let (empty_module, empty_proof_typename_str) = receiver_registry::get_receiver_info( + let (empty_module, empty_proof_typename_str, empty_count) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(empty_module == string::utf8(b"")); assert!(empty_proof_typename_str == ascii::string(b"")); + assert!(empty_count == 0); cleanup_test(scenario, ref, owner_cap); } diff --git a/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move b/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move index e0ed0322e..9ea6ab139 100644 --- a/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move +++ b/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move @@ -87,7 +87,8 @@ fun init(otw: DUMMY_RECEIVER, ctx: &mut TxContext) { public fun register_receiver(owner_cap: &OwnerCap, ref: &mut CCIPObjectRef) { let publisher: &Publisher = df::borrow(&owner_cap.id, PublisherKey {}); let publisher_wrapper = publisher_wrapper::create(publisher, DummyReceiverProof {}); - receiver_registry::register_receiver(ref, publisher_wrapper, DummyReceiverProof {}); + // 2 extra object IDs: &Clock and &mut CCIPReceiverState + receiver_registry::register_receiver(ref, publisher_wrapper, DummyReceiverProof {}, 2); } public fun get_counter(state: &CCIPReceiverState): u64 { From 7b375be639681fff3de1ef8c8f24240a2fff5bc4 Mon Sep 17 00:00:00 2001 From: faisal-link Date: Tue, 7 Apr 2026 23:16:50 +0400 Subject: [PATCH 2/4] extend relayer executor tests --- relayer/chainwriter/ptb/offramp/execute.go | 20 ++ .../ptb/offramp/receiver_validation.go | 146 ++++++++ .../ptb/offramp/receiver_validation_test.go | 331 ++++++++++++++++++ 3 files changed, 497 insertions(+) create mode 100644 relayer/chainwriter/ptb/offramp/receiver_validation.go create mode 100644 relayer/chainwriter/ptb/offramp/receiver_validation_test.go diff --git a/relayer/chainwriter/ptb/offramp/execute.go b/relayer/chainwriter/ptb/offramp/execute.go index 56636229c..1b4b354a4 100644 --- a/relayer/chainwriter/ptb/offramp/execute.go +++ b/relayer/chainwriter/ptb/offramp/execute.go @@ -489,6 +489,16 @@ func AppendPTBCommandForReceiver( return nil, fmt.Errorf("failed to decode parameters for token pool function: %w", err) } + if err := ValidateReceiverCallbackSignature( + lggr, + functionSignature.(map[string]any), + paramTypes, + addressMappings.CcipPackageId, + addressMappings.OffRampPackageId, + ); err != nil { + return nil, fmt.Errorf("receiver callback validation failed for %s::%s: %w", moduleId, functionName, err) + } + lggr.Debugw("calling receiver", "paramTypes", paramTypes, "paramValues", paramValues) // Append extra args to the paramValues for the receiver call. @@ -515,11 +525,21 @@ func AppendPTBCommandForReceiver( lggr.Error("unexpected receiverObjectIds type", "type", fmt.Sprintf("%T", receiverObjectIds)) } + if err := ValidateReceiverObjectIdCount(paramTypes, len(extraArgsValues)); err != nil { + return nil, fmt.Errorf("receiver %s::%s: %w", moduleId, functionName, err) + } + + var receiverObjectIdStrings []string for _, value := range extraArgsValues { objectId := hex.EncodeToString(value) + receiverObjectIdStrings = append(receiverObjectIdStrings, "0x"+objectId) paramValues = append(paramValues, bind.Object{Id: "0x" + objectId}) } + if err := ValidateReceiverObjectIds(receiverObjectIdStrings, addressMappings); err != nil { + return nil, fmt.Errorf("receiver %s::%s: %w", moduleId, functionName, err) + } + encodedReceiverCall, err := boundReceiverContract.EncodeCallArgsWithGenerics( functionName, typeArgsList, diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation.go b/relayer/chainwriter/ptb/offramp/receiver_validation.go new file mode 100644 index 000000000..8696042f6 --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/receiver_validation.go @@ -0,0 +1,146 @@ +package offramp + +import ( + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// receiverStandardParamCount is the number of standard parameters that every +// ccip_receive callback must begin with: +// +// [0] expected_message_id: vector +// [1] ref: &CCIPObjectRef +// [2] message: Any2SuiMessage +const receiverStandardParamCount = 3 + +// ValidateReceiverCallbackSignature validates that a receiver's ccip_receive +// callback does not declare extra parameters whose types belong to known CCIP +// protocol packages. This prevents a malicious receiver from tricking the +// relayer into injecting protocol-owned objects (e.g. OnRampState) as mutable +// PTB inputs in the transmitter-signed transaction. +func ValidateReceiverCallbackSignature( + lggr logger.Logger, + functionSig map[string]any, + decodedParamTypes []string, + ccipPackageId string, + offRampPackageId string, +) error { + if len(decodedParamTypes) < receiverStandardParamCount { + return fmt.Errorf( + "receiver callback has %d parameters, expected at least %d standard parameters "+ + "(expected_message_id, &CCIPObjectRef, Any2SuiMessage)", + len(decodedParamTypes), receiverStandardParamCount, + ) + } + + parametersRaw, ok := functionSig["parameters"] + if !ok { + return fmt.Errorf("missing 'parameters' field in receiver function signature") + } + parameters, ok := parametersRaw.([]any) + if !ok { + return fmt.Errorf("'parameters' field is not an array in receiver function signature") + } + + // Walk raw parameters, skipping TxContext (mirroring DecodeParameters), + // and inspect every extra parameter beyond the standard 3. + decodedIdx := 0 + for i, rawParam := range parameters { + meta := decodeParam(lggr, rawParam, "Reference") + if meta.Name == "TxContext" { + continue + } + + if decodedIdx >= receiverStandardParamCount { + if meta.Reference == "MutableReference" { + if isDeniedProtocolPackage(meta.Address, ccipPackageId, offRampPackageId) { + return fmt.Errorf( + "receiver callback parameter %d declares mutable reference to CCIP protocol type %s::%s::%s; "+ + "receiver callbacks must not accept mutable references to CCIP protocol objects", + i, meta.Address, meta.Module, meta.Name, + ) + } + if isDeniedProtocolModule(meta.Module, meta.Name) { + return fmt.Errorf( + "receiver callback parameter %d references denied protocol type %s::%s; "+ + "receiver callbacks must not accept references to CCIP internal objects", + i, meta.Module, meta.Name, + ) + } + } + } + + decodedIdx++ + } + + return nil +} + +// ValidateReceiverObjectIdCount ensures the number of receiverObjectIds matches +// the number of extra parameters declared by the callback beyond the standard 3. +// A mismatch indicates the callback ABI and the message's extra args are +// inconsistent, which is a precondition for the object injection attack. +func ValidateReceiverObjectIdCount(decodedParamTypes []string, receiverObjectIdCount int) error { + expectedExtraParams := len(decodedParamTypes) - receiverStandardParamCount + if expectedExtraParams < 0 { + expectedExtraParams = 0 + } + if receiverObjectIdCount != expectedExtraParams { + return fmt.Errorf( + "receiver callback declares %d extra object parameters but receiverObjectIds contains %d entries; counts must match", + expectedExtraParams, receiverObjectIdCount, + ) + } + return nil +} + +// ValidateReceiverObjectIds checks that none of the supplied receiver object +// IDs reference known CCIP protocol objects. Accepting protocol objects as +// receiver callback arguments would let a malicious receiver modify protocol +// state via the transmitter-signed PTB. +func ValidateReceiverObjectIds(objectIds []string, addressMappings *OffRampAddressMappings) error { + denied := map[string]string{ + addressMappings.CcipObjectRef: "CCIPObjectRef", + addressMappings.OffRampState: "OffRampState", + } + if addressMappings.CcipOwnerCap != "" { + denied[addressMappings.CcipOwnerCap] = "CcipOwnerCap" + } + + for i, objectId := range objectIds { + if name, found := denied[objectId]; found { + return fmt.Errorf( + "receiverObjectIds[%d] (%s) references protocol object %s; "+ + "receiver callbacks must not be passed CCIP protocol objects", + i, objectId, name, + ) + } + } + return nil +} + +func isDeniedProtocolPackage(addr, ccipPackageId, offRampPackageId string) bool { + return addr != "" && (addr == ccipPackageId || addr == offRampPackageId) +} + +// isDeniedProtocolModule provides a defense-in-depth check against known CCIP +// protocol module+type combinations. This catches cases where the attacker's +// package references protocol types whose package ID isn't in addressMappings +// (e.g. the onramp package). +func isDeniedProtocolModule(module, name string) bool { + denied := map[string]map[string]bool{ + "onramp": {"OnRampState": true}, + "offramp": {"OffRampState": true}, + "fee_quoter": {"FeeQuoterState": true}, + "token_admin_registry": {"TokenAdminRegistryState": true}, + "receiver_registry": {"ReceiverRegistry": true}, + "nonce_manager": {"NonceManagerState": true}, + "state_object": {"CCIPObjectRef": true}, + "offramp_state_helper": {"ReceiverParams": true}, + } + if names, ok := denied[module]; ok { + return names[name] + } + return false +} diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation_test.go b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go new file mode 100644 index 000000000..9d39301de --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go @@ -0,0 +1,331 @@ +package offramp + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testCcipPackageId = "0x00000000000000000000000000000000000000000000000000000000ccipccip" + testOffRampPackageId = "0x000000000000000000000000000000000000000000000000000000000ff2a3f0" + testCcipObjectRef = "0x000000000000000000000000000000000000000000000000000000000bj3c7r3" + testOffRampState = "0x000000000000000000000000000000000000000000000000000000000ff2a357" + testCcipOwnerCap = "0x0000000000000000000000000000000000000000000000000000000000ca9ca9" +) + +func testAddressMappings() *OffRampAddressMappings { + return &OffRampAddressMappings{ + CcipPackageId: testCcipPackageId, + CcipObjectRef: testCcipObjectRef, + CcipOwnerCap: testCcipOwnerCap, + OffRampPackageId: testOffRampPackageId, + OffRampState: testOffRampState, + } +} + +func standardParams(ccipPkgId string) []any { + return []any{ + map[string]any{"Vector": "U8"}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": ccipPkgId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": ccipPkgId, + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + } +} + +func TestValidateReceiverCallbackSignature_StandardParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := standardParams(testCcipPackageId) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err) +} + +func TestValidateReceiverCallbackSignature_LegitExtraParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "clock", + "name": "Clock", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0xdeadbeef", + "module": "my_receiver", + "name": "ReceiverState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&object", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "legitimate extra params (Clock + receiver's own state) should pass") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableCcipProtocolType(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "fee_quoter", + "name": "FeeQuoterState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutable reference to CCIP protocol type") + assert.Contains(t, err.Error(), "FeeQuoterState") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableOnRampState(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + onrampPackageId := "0x0000000000000000000000000000000000000000000000000000000000012345" + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": onrampPackageId, + "module": "onramp", + "name": "OnRampState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "OnRampState should be caught by module name denylist even when package ID is unknown") + assert.Contains(t, err.Error(), "denied protocol type") + assert.Contains(t, err.Error(), "OnRampState") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableOffRampPackageType(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": testOffRampPackageId, + "module": "offramp", + "name": "OffRampState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "CCIP protocol type") +} + +func TestValidateReceiverCallbackSignature_TooFewParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{"parameters": []any{map[string]any{"Vector": "U8"}}} + decodedTypes := []string{"vector"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected at least 3 standard parameters") +} + +func TestValidateReceiverCallbackSignature_TxContextSkipped(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "tx_context", + "name": "TxContext", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "TxContext parameter should be skipped") +} + +func TestValidateReceiverObjectIdCount_Matching(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id", "&object", "&mut object"} + err := ValidateReceiverObjectIdCount(decodedTypes, 2) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIdCount_ExactlyStandard(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id"} + err := ValidateReceiverObjectIdCount(decodedTypes, 0) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIdCount_Mismatch_TooMany(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id"} + err := ValidateReceiverObjectIdCount(decodedTypes, 2) + require.Error(t, err) + assert.Contains(t, err.Error(), "declares 0 extra object parameters but receiverObjectIds contains 2") +} + +func TestValidateReceiverObjectIdCount_Mismatch_TooFew(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id", "&mut object", "&mut object"} + err := ValidateReceiverObjectIdCount(decodedTypes, 1) + require.Error(t, err) + assert.Contains(t, err.Error(), "declares 2 extra object parameters but receiverObjectIds contains 1") +} + +func TestValidateReceiverObjectIds_Safe(t *testing.T) { + t.Parallel() + + objectIds := []string{ + "0x0000000000000000000000000000000000000000000000000000000000aaaaaa", + "0x0000000000000000000000000000000000000000000000000000000000bbbbbb", + } + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIds_RejectsCcipObjectRef(t *testing.T) { + t.Parallel() + + objectIds := []string{testCcipObjectRef} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "CCIPObjectRef") +} + +func TestValidateReceiverObjectIds_RejectsOffRampState(t *testing.T) { + t.Parallel() + + objectIds := []string{testOffRampState} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "OffRampState") +} + +func TestValidateReceiverObjectIds_RejectsCcipOwnerCap(t *testing.T) { + t.Parallel() + + objectIds := []string{testCcipOwnerCap} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "CcipOwnerCap") +} + +func TestValidateReceiverObjectIds_EmptyList(t *testing.T) { + t.Parallel() + + err := ValidateReceiverObjectIds(nil, testAddressMappings()) + require.NoError(t, err) +} + +func TestIsDeniedProtocolPackage(t *testing.T) { + t.Parallel() + + assert.True(t, isDeniedProtocolPackage(testCcipPackageId, testCcipPackageId, testOffRampPackageId)) + assert.True(t, isDeniedProtocolPackage(testOffRampPackageId, testCcipPackageId, testOffRampPackageId)) + assert.False(t, isDeniedProtocolPackage("0xdeadbeef", testCcipPackageId, testOffRampPackageId)) + assert.False(t, isDeniedProtocolPackage("", testCcipPackageId, testOffRampPackageId)) +} + +func TestIsDeniedProtocolModule(t *testing.T) { + t.Parallel() + + assert.True(t, isDeniedProtocolModule("onramp", "OnRampState")) + assert.True(t, isDeniedProtocolModule("offramp", "OffRampState")) + assert.True(t, isDeniedProtocolModule("fee_quoter", "FeeQuoterState")) + assert.True(t, isDeniedProtocolModule("state_object", "CCIPObjectRef")) + assert.False(t, isDeniedProtocolModule("my_receiver", "ReceiverState")) + assert.False(t, isDeniedProtocolModule("onramp", "SomeOtherType")) + assert.False(t, isDeniedProtocolModule("clock", "Clock")) +} + +func TestValidateReceiverCallbackSignature_ImmutableCcipRefAllowed(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + // Immutable reference to a CCIP type as an extra param should be allowed + // (read-only access is not dangerous in the same way mutable access is). + params := append(standardParams(testCcipPackageId), + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "immutable references are safe; only mutable references to protocol types are denied") +} From 57a4552db8b5e248f70553c9a3cc549b0f8d8e65 Mon Sep 17 00:00:00 2001 From: faisal-link Date: Wed, 8 Apr 2026 21:35:01 +0400 Subject: [PATCH 3/4] registry bindings gen --- .../receiver_registry/receiver_registry.go | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go b/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go index 24bd7ef01..88acd2617 100644 --- a/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go +++ b/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go @@ -19,12 +19,12 @@ var ( _ = big.NewInt ) -const FunctionInfo = `[{"package":"ccip","module":"receiver_registry","name":"get_receiver_config","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_config_fields","parameters":[{"name":"rc","type":"ReceiverConfig"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_info","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"initialize","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"}]},{"package":"ccip","module":"receiver_registry","name":"is_registered_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"register_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"publisher_wrapper","type":"PublisherWrapper"},{"name":"_proof","type":"ProofType"}]},{"package":"ccip","module":"receiver_registry","name":"type_and_version","parameters":null},{"package":"ccip","module":"receiver_registry","name":"unregister_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"},{"name":"receiver_package_id","type":"address"}]}]` +const FunctionInfo = `[{"package":"ccip","module":"receiver_registry","name":"get_receiver_config","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_config_fields","parameters":[{"name":"rc","type":"ReceiverConfig"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_info","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"initialize","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"}]},{"package":"ccip","module":"receiver_registry","name":"is_registered_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"register_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"publisher_wrapper","type":"PublisherWrapper"},{"name":"_proof","type":"ProofType"},{"name":"expected_receiver_object_id_count","type":"u64"}]},{"package":"ccip","module":"receiver_registry","name":"type_and_version","parameters":null},{"package":"ccip","module":"receiver_registry","name":"unregister_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"},{"name":"receiver_package_id","type":"address"}]}]` type IReceiverRegistry interface { TypeAndVersion(ctx context.Context, opts *bind.CallOpts) (*models.SuiTransactionBlockResponse, error) Initialize(ctx context.Context, opts *bind.CallOpts, ref bind.Object, ownerCap bind.Object) (*models.SuiTransactionBlockResponse, error) - RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*models.SuiTransactionBlockResponse, error) + RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*models.SuiTransactionBlockResponse, error) UnregisterReceiver(ctx context.Context, opts *bind.CallOpts, ref bind.Object, ownerCap bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) IsRegisteredReceiver(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) GetReceiverConfig(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) @@ -48,7 +48,7 @@ type ReceiverRegistryEncoder interface { TypeAndVersionWithArgs(args ...any) (*bind.EncodedCall, error) Initialize(ref bind.Object, ownerCap bind.Object) (*bind.EncodedCall, error) InitializeWithArgs(args ...any) (*bind.EncodedCall, error) - RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*bind.EncodedCall, error) + RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*bind.EncodedCall, error) RegisterReceiverWithArgs(typeArgs []string, args ...any) (*bind.EncodedCall, error) UnregisterReceiver(ref bind.Object, ownerCap bind.Object, receiverPackageId string) (*bind.EncodedCall, error) UnregisterReceiverWithArgs(args ...any) (*bind.EncodedCall, error) @@ -102,8 +102,9 @@ func (c *ReceiverRegistryContract) DevInspect() IReceiverRegistryDevInspect { } type ReceiverConfig struct { - ModuleName string `move:"0x1::string::String"` - ProofTypename string `move:"ascii::String"` + ModuleName string `move:"0x1::string::String"` + ProofTypename string `move:"ascii::String"` + ExpectedReceiverObjectIdCount uint64 `move:"u64"` } type ReceiverRegistry struct { @@ -112,9 +113,10 @@ type ReceiverRegistry struct { } type ReceiverRegistered struct { - ReceiverPackageId string `move:"address"` - ReceiverModuleName string `move:"0x1::string::String"` - ProofTypename string `move:"ascii::String"` + ReceiverPackageId string `move:"address"` + ReceiverModuleName string `move:"0x1::string::String"` + ProofTypename string `move:"ascii::String"` + ExpectedReceiverObjectIdCount uint64 `move:"u64"` } type ReceiverUnregistered struct { @@ -122,17 +124,19 @@ type ReceiverUnregistered struct { } type bcsReceiverRegistered struct { - ReceiverPackageId [32]byte - ReceiverModuleName string - ProofTypename string + ReceiverPackageId [32]byte + ReceiverModuleName string + ProofTypename string + ExpectedReceiverObjectIdCount uint64 } func convertReceiverRegisteredFromBCS(bcs bcsReceiverRegistered) (ReceiverRegistered, error) { return ReceiverRegistered{ - ReceiverPackageId: fmt.Sprintf("0x%x", bcs.ReceiverPackageId), - ReceiverModuleName: bcs.ReceiverModuleName, - ProofTypename: bcs.ProofTypename, + ReceiverPackageId: fmt.Sprintf("0x%x", bcs.ReceiverPackageId), + ReceiverModuleName: bcs.ReceiverModuleName, + ProofTypename: bcs.ProofTypename, + ExpectedReceiverObjectIdCount: bcs.ExpectedReceiverObjectIdCount, }, nil } @@ -267,8 +271,8 @@ func (c *ReceiverRegistryContract) Initialize(ctx context.Context, opts *bind.Ca } // RegisterReceiver executes the register_receiver Move function. -func (c *ReceiverRegistryContract) RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*models.SuiTransactionBlockResponse, error) { - encoded, err := c.receiverRegistryEncoder.RegisterReceiver(typeArgs, ref, publisherWrapper, proof) +func (c *ReceiverRegistryContract) RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*models.SuiTransactionBlockResponse, error) { + encoded, err := c.receiverRegistryEncoder.RegisterReceiver(typeArgs, ref, publisherWrapper, proof, expectedReceiverObjectIdCount) if err != nil { return nil, fmt.Errorf("failed to encode function call: %w", err) } @@ -398,6 +402,7 @@ func (d *ReceiverRegistryDevInspect) GetReceiverConfig(ctx context.Context, opts // // [0]: 0x1::string::String // [1]: ascii::String +// [2]: u64 func (d *ReceiverRegistryDevInspect) GetReceiverConfigFields(ctx context.Context, opts *bind.CallOpts, rc ReceiverConfig) ([]any, error) { encoded, err := d.contract.receiverRegistryEncoder.GetReceiverConfigFields(rc) if err != nil { @@ -412,6 +417,7 @@ func (d *ReceiverRegistryDevInspect) GetReceiverConfigFields(ctx context.Context // // [0]: 0x1::string::String // [1]: ascii::String +// [2]: u64 func (d *ReceiverRegistryDevInspect) GetReceiverInfo(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) ([]any, error) { encoded, err := d.contract.receiverRegistryEncoder.GetReceiverInfo(ref, receiverPackageId) if err != nil { @@ -478,7 +484,7 @@ func (c receiverRegistryEncoder) InitializeWithArgs(args ...any) (*bind.EncodedC } // RegisterReceiver encodes a call to the register_receiver Move function. -func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*bind.EncodedCall, error) { +func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*bind.EncodedCall, error) { typeArgsList := typeArgs typeParamsList := []string{ "ProofType", @@ -487,10 +493,12 @@ func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Ob "&mut CCIPObjectRef", "PublisherWrapper", "ProofType", + "u64", }, []any{ ref, publisherWrapper, proof, + expectedReceiverObjectIdCount, }, nil) } @@ -501,6 +509,7 @@ func (c receiverRegistryEncoder) RegisterReceiverWithArgs(typeArgs []string, arg "&mut CCIPObjectRef", "PublisherWrapper", "ProofType", + "u64", } if len(args) != len(expectedParams) { @@ -622,6 +631,7 @@ func (c receiverRegistryEncoder) GetReceiverConfigFields(rc ReceiverConfig) (*bi }, []string{ "0x1::string::String", "ascii::String", + "u64", }) } @@ -640,6 +650,7 @@ func (c receiverRegistryEncoder) GetReceiverConfigFieldsWithArgs(args ...any) (* return c.EncodeCallArgsWithGenerics("get_receiver_config_fields", typeArgsList, typeParamsList, expectedParams, args, []string{ "0x1::string::String", "ascii::String", + "u64", }) } @@ -656,6 +667,7 @@ func (c receiverRegistryEncoder) GetReceiverInfo(ref bind.Object, receiverPackag }, []string{ "0x1::string::String", "ascii::String", + "u64", }) } @@ -675,5 +687,6 @@ func (c receiverRegistryEncoder) GetReceiverInfoWithArgs(args ...any) (*bind.Enc return c.EncodeCallArgsWithGenerics("get_receiver_info", typeArgsList, typeParamsList, expectedParams, args, []string{ "0x1::string::String", "ascii::String", + "u64", }) } From f3394ca3a2bde1832b0643cc3cb7f7d592f7b8b6 Mon Sep 17 00:00:00 2001 From: faisal-link Date: Thu, 9 Apr 2026 00:13:24 +0400 Subject: [PATCH 4/4] fix: harden off-chain ABI decoding against malicious receiver metadata (Report #71024) - Revert on-chain receiver_registry changes to keep fix entirely off-chain - Convert decodeParam panics to errors with checked type assertions - Add explicit TypeParameter rejection in ABI parameter decoding - Add defer/recover defense-in-depth in BuildOffRampExecutePTB - Fix unchecked assertions in token pool and receiver PTB construction - Add comprehensive unit tests for malformed and adversarial ABI shapes --- .../receiver_registry/receiver_registry.go | 47 +-- .../ccip/sources/offramp_state_helper.move | 2 +- .../ccip/ccip/sources/receiver_registry.move | 24 +- .../tests/offramp_state_helper_tests.move | 1 - .../ccip/tests/receiver_registry_tests.move | 29 +- .../sources/ccip_dummy_receiver.move | 3 +- relayer/chainwriter/ptb/offramp/execute.go | 33 +- relayer/chainwriter/ptb/offramp/helpers.go | 156 ++++++-- .../chainwriter/ptb/offramp/helpers_test.go | 369 ++++++++++++++++++ .../ptb/offramp/receiver_validation.go | 7 +- .../ptb/offramp/receiver_validation_test.go | 54 ++- 11 files changed, 611 insertions(+), 114 deletions(-) create mode 100644 relayer/chainwriter/ptb/offramp/helpers_test.go diff --git a/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go b/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go index 88acd2617..24bd7ef01 100644 --- a/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go +++ b/bindings/generated/ccip/ccip/receiver_registry/receiver_registry.go @@ -19,12 +19,12 @@ var ( _ = big.NewInt ) -const FunctionInfo = `[{"package":"ccip","module":"receiver_registry","name":"get_receiver_config","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_config_fields","parameters":[{"name":"rc","type":"ReceiverConfig"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_info","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"initialize","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"}]},{"package":"ccip","module":"receiver_registry","name":"is_registered_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"register_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"publisher_wrapper","type":"PublisherWrapper"},{"name":"_proof","type":"ProofType"},{"name":"expected_receiver_object_id_count","type":"u64"}]},{"package":"ccip","module":"receiver_registry","name":"type_and_version","parameters":null},{"package":"ccip","module":"receiver_registry","name":"unregister_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"},{"name":"receiver_package_id","type":"address"}]}]` +const FunctionInfo = `[{"package":"ccip","module":"receiver_registry","name":"get_receiver_config","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_config_fields","parameters":[{"name":"rc","type":"ReceiverConfig"}]},{"package":"ccip","module":"receiver_registry","name":"get_receiver_info","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"initialize","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"}]},{"package":"ccip","module":"receiver_registry","name":"is_registered_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"receiver_package_id","type":"address"}]},{"package":"ccip","module":"receiver_registry","name":"register_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"publisher_wrapper","type":"PublisherWrapper"},{"name":"_proof","type":"ProofType"}]},{"package":"ccip","module":"receiver_registry","name":"type_and_version","parameters":null},{"package":"ccip","module":"receiver_registry","name":"unregister_receiver","parameters":[{"name":"ref","type":"CCIPObjectRef"},{"name":"owner_cap","type":"OwnerCap"},{"name":"receiver_package_id","type":"address"}]}]` type IReceiverRegistry interface { TypeAndVersion(ctx context.Context, opts *bind.CallOpts) (*models.SuiTransactionBlockResponse, error) Initialize(ctx context.Context, opts *bind.CallOpts, ref bind.Object, ownerCap bind.Object) (*models.SuiTransactionBlockResponse, error) - RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*models.SuiTransactionBlockResponse, error) + RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*models.SuiTransactionBlockResponse, error) UnregisterReceiver(ctx context.Context, opts *bind.CallOpts, ref bind.Object, ownerCap bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) IsRegisteredReceiver(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) GetReceiverConfig(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) (*models.SuiTransactionBlockResponse, error) @@ -48,7 +48,7 @@ type ReceiverRegistryEncoder interface { TypeAndVersionWithArgs(args ...any) (*bind.EncodedCall, error) Initialize(ref bind.Object, ownerCap bind.Object) (*bind.EncodedCall, error) InitializeWithArgs(args ...any) (*bind.EncodedCall, error) - RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*bind.EncodedCall, error) + RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*bind.EncodedCall, error) RegisterReceiverWithArgs(typeArgs []string, args ...any) (*bind.EncodedCall, error) UnregisterReceiver(ref bind.Object, ownerCap bind.Object, receiverPackageId string) (*bind.EncodedCall, error) UnregisterReceiverWithArgs(args ...any) (*bind.EncodedCall, error) @@ -102,9 +102,8 @@ func (c *ReceiverRegistryContract) DevInspect() IReceiverRegistryDevInspect { } type ReceiverConfig struct { - ModuleName string `move:"0x1::string::String"` - ProofTypename string `move:"ascii::String"` - ExpectedReceiverObjectIdCount uint64 `move:"u64"` + ModuleName string `move:"0x1::string::String"` + ProofTypename string `move:"ascii::String"` } type ReceiverRegistry struct { @@ -113,10 +112,9 @@ type ReceiverRegistry struct { } type ReceiverRegistered struct { - ReceiverPackageId string `move:"address"` - ReceiverModuleName string `move:"0x1::string::String"` - ProofTypename string `move:"ascii::String"` - ExpectedReceiverObjectIdCount uint64 `move:"u64"` + ReceiverPackageId string `move:"address"` + ReceiverModuleName string `move:"0x1::string::String"` + ProofTypename string `move:"ascii::String"` } type ReceiverUnregistered struct { @@ -124,19 +122,17 @@ type ReceiverUnregistered struct { } type bcsReceiverRegistered struct { - ReceiverPackageId [32]byte - ReceiverModuleName string - ProofTypename string - ExpectedReceiverObjectIdCount uint64 + ReceiverPackageId [32]byte + ReceiverModuleName string + ProofTypename string } func convertReceiverRegisteredFromBCS(bcs bcsReceiverRegistered) (ReceiverRegistered, error) { return ReceiverRegistered{ - ReceiverPackageId: fmt.Sprintf("0x%x", bcs.ReceiverPackageId), - ReceiverModuleName: bcs.ReceiverModuleName, - ProofTypename: bcs.ProofTypename, - ExpectedReceiverObjectIdCount: bcs.ExpectedReceiverObjectIdCount, + ReceiverPackageId: fmt.Sprintf("0x%x", bcs.ReceiverPackageId), + ReceiverModuleName: bcs.ReceiverModuleName, + ProofTypename: bcs.ProofTypename, }, nil } @@ -271,8 +267,8 @@ func (c *ReceiverRegistryContract) Initialize(ctx context.Context, opts *bind.Ca } // RegisterReceiver executes the register_receiver Move function. -func (c *ReceiverRegistryContract) RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*models.SuiTransactionBlockResponse, error) { - encoded, err := c.receiverRegistryEncoder.RegisterReceiver(typeArgs, ref, publisherWrapper, proof, expectedReceiverObjectIdCount) +func (c *ReceiverRegistryContract) RegisterReceiver(ctx context.Context, opts *bind.CallOpts, typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*models.SuiTransactionBlockResponse, error) { + encoded, err := c.receiverRegistryEncoder.RegisterReceiver(typeArgs, ref, publisherWrapper, proof) if err != nil { return nil, fmt.Errorf("failed to encode function call: %w", err) } @@ -402,7 +398,6 @@ func (d *ReceiverRegistryDevInspect) GetReceiverConfig(ctx context.Context, opts // // [0]: 0x1::string::String // [1]: ascii::String -// [2]: u64 func (d *ReceiverRegistryDevInspect) GetReceiverConfigFields(ctx context.Context, opts *bind.CallOpts, rc ReceiverConfig) ([]any, error) { encoded, err := d.contract.receiverRegistryEncoder.GetReceiverConfigFields(rc) if err != nil { @@ -417,7 +412,6 @@ func (d *ReceiverRegistryDevInspect) GetReceiverConfigFields(ctx context.Context // // [0]: 0x1::string::String // [1]: ascii::String -// [2]: u64 func (d *ReceiverRegistryDevInspect) GetReceiverInfo(ctx context.Context, opts *bind.CallOpts, ref bind.Object, receiverPackageId string) ([]any, error) { encoded, err := d.contract.receiverRegistryEncoder.GetReceiverInfo(ref, receiverPackageId) if err != nil { @@ -484,7 +478,7 @@ func (c receiverRegistryEncoder) InitializeWithArgs(args ...any) (*bind.EncodedC } // RegisterReceiver encodes a call to the register_receiver Move function. -func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object, expectedReceiverObjectIdCount uint64) (*bind.EncodedCall, error) { +func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Object, publisherWrapper bind.Object, proof bind.Object) (*bind.EncodedCall, error) { typeArgsList := typeArgs typeParamsList := []string{ "ProofType", @@ -493,12 +487,10 @@ func (c receiverRegistryEncoder) RegisterReceiver(typeArgs []string, ref bind.Ob "&mut CCIPObjectRef", "PublisherWrapper", "ProofType", - "u64", }, []any{ ref, publisherWrapper, proof, - expectedReceiverObjectIdCount, }, nil) } @@ -509,7 +501,6 @@ func (c receiverRegistryEncoder) RegisterReceiverWithArgs(typeArgs []string, arg "&mut CCIPObjectRef", "PublisherWrapper", "ProofType", - "u64", } if len(args) != len(expectedParams) { @@ -631,7 +622,6 @@ func (c receiverRegistryEncoder) GetReceiverConfigFields(rc ReceiverConfig) (*bi }, []string{ "0x1::string::String", "ascii::String", - "u64", }) } @@ -650,7 +640,6 @@ func (c receiverRegistryEncoder) GetReceiverConfigFieldsWithArgs(args ...any) (* return c.EncodeCallArgsWithGenerics("get_receiver_config_fields", typeArgsList, typeParamsList, expectedParams, args, []string{ "0x1::string::String", "ascii::String", - "u64", }) } @@ -667,7 +656,6 @@ func (c receiverRegistryEncoder) GetReceiverInfo(ref bind.Object, receiverPackag }, []string{ "0x1::string::String", "ascii::String", - "u64", }) } @@ -687,6 +675,5 @@ func (c receiverRegistryEncoder) GetReceiverInfoWithArgs(args ...any) (*bind.Enc return c.EncodeCallArgsWithGenerics("get_receiver_info", typeArgsList, typeParamsList, expectedParams, args, []string{ "0x1::string::String", "ascii::String", - "u64", }) } diff --git a/contracts/ccip/ccip/sources/offramp_state_helper.move b/contracts/ccip/ccip/sources/offramp_state_helper.move index d4a1acfdd..dbd2ac86e 100644 --- a/contracts/ccip/ccip/sources/offramp_state_helper.move +++ b/contracts/ccip/ccip/sources/offramp_state_helper.move @@ -233,7 +233,7 @@ public fun consume_any2sui_message( let receiver_package_id = address::from_ascii_bytes(&ascii::into_bytes(address_str)); let receiver_config = receiver_registry::get_receiver_config(ref, receiver_package_id); - let (_, proof_typename, _) = receiver_registry::get_receiver_config_fields(receiver_config); + let (_, proof_typename) = receiver_registry::get_receiver_config_fields(receiver_config); assert!(proof_typename == proof_tn.into_string(), ETypeProofMismatch); client::consume_any2sui_message(message, receiver_package_id) diff --git a/contracts/ccip/ccip/sources/receiver_registry.move b/contracts/ccip/ccip/sources/receiver_registry.move index dd7532961..dc0337906 100644 --- a/contracts/ccip/ccip/sources/receiver_registry.move +++ b/contracts/ccip/ccip/sources/receiver_registry.move @@ -13,12 +13,6 @@ use sui::linked_table::{Self, LinkedTable}; public struct ReceiverConfig has copy, drop, store { module_name: String, proof_typename: ascii::String, - /// The number of extra object IDs that the receiver's ccip_receive callback - /// expects beyond the standard 3 parameters (expected_message_id, - /// &CCIPObjectRef, Any2SuiMessage). The relayer uses this to validate that - /// the receiverObjectIds count in a CCIP message matches what the receiver - /// registered, preventing object injection attacks. - expected_receiver_object_id_count: u64, } public struct ReceiverRegistry has key, store { @@ -31,7 +25,6 @@ public struct ReceiverRegistered has copy, drop { receiver_package_id: address, receiver_module_name: String, proof_typename: ascii::String, - expected_receiver_object_id_count: u64, } public struct ReceiverUnregistered has copy, drop { @@ -64,7 +57,6 @@ public fun register_receiver( ref: &mut CCIPObjectRef, publisher_wrapper: PublisherWrapper, _proof: ProofType, - expected_receiver_object_id_count: u64, ) { verify_function_allowed( ref, @@ -81,7 +73,6 @@ public fun register_receiver( let receiver_config = ReceiverConfig { module_name: receiver_module_name, proof_typename: proof_typename.into_string(), - expected_receiver_object_id_count, }; registry.receiver_configs.push_back(receiver_package_id, receiver_config); @@ -89,7 +80,6 @@ public fun register_receiver( receiver_package_id, receiver_module_name, proof_typename: proof_typename.into_string(), - expected_receiver_object_id_count, }); } @@ -142,15 +132,15 @@ public fun get_receiver_config(ref: &CCIPObjectRef, receiver_package_id: address *registry.receiver_configs.borrow(receiver_package_id) } -public fun get_receiver_config_fields(rc: ReceiverConfig): (String, ascii::String, u64) { - (rc.module_name, rc.proof_typename, rc.expected_receiver_object_id_count) +public fun get_receiver_config_fields(rc: ReceiverConfig): (String, ascii::String) { + (rc.module_name, rc.proof_typename) } // this will return empty string if the receiver is not registered. public fun get_receiver_info( ref: &CCIPObjectRef, receiver_package_id: address, -): (String, ascii::String, u64) { +): (String, ascii::String) { verify_function_allowed( ref, string::utf8(b"receiver_registry"), @@ -161,12 +151,8 @@ public fun get_receiver_info( if (registry.receiver_configs.contains(receiver_package_id)) { let receiver_config = registry.receiver_configs.borrow(receiver_package_id); - return ( - receiver_config.module_name, - receiver_config.proof_typename, - receiver_config.expected_receiver_object_id_count, - ) + return (receiver_config.module_name, receiver_config.proof_typename) }; - (string::utf8(b""), ascii::string(b""), 0) + (string::utf8(b""), ascii::string(b"")) } diff --git a/contracts/ccip/ccip/tests/offramp_state_helper_tests.move b/contracts/ccip/ccip/tests/offramp_state_helper_tests.move index 5cb8f4980..0f69d4ef5 100644 --- a/contracts/ccip/ccip/tests/offramp_state_helper_tests.move +++ b/contracts/ccip/ccip/tests/offramp_state_helper_tests.move @@ -258,7 +258,6 @@ public fun test_extract_any2sui_message() { &mut ref, publisher_wrapper, TestTypeProof {}, - 0, ); package::burn_publisher(publisher); diff --git a/contracts/ccip/ccip/tests/receiver_registry_tests.move b/contracts/ccip/ccip/tests/receiver_registry_tests.move index 54e9d50ad..43d668de6 100644 --- a/contracts/ccip/ccip/tests/receiver_registry_tests.move +++ b/contracts/ccip/ccip/tests/receiver_registry_tests.move @@ -67,7 +67,7 @@ fun register_test_receiver( let publisher = package::test_claim(RECEIVER_REGISTRY_TESTS {}, ctx); let publisher_wrapper = publisher_wrapper::create(&publisher, proof); - receiver_registry::register_receiver(ref, publisher_wrapper, proof, 0); + receiver_registry::register_receiver(ref, publisher_wrapper, proof); package::burn_publisher(publisher); } @@ -114,13 +114,12 @@ public fun test_register_receiver() { // Get receiver config and verify fields let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); - assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -179,7 +178,7 @@ public fun test_register_multiple_receivers_same_package() { // Verify the config contains the first proof type let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (_, proof_type, _) = receiver_registry::get_receiver_config_fields(config); + let (_, proof_type) = receiver_registry::get_receiver_config_fields(config); assert!( proof_type == type_name::into_string(type_name::with_defining_ids()), @@ -279,14 +278,13 @@ public fun test_get_receiver_config() { // Get the config let package_id_1 = get_package_id_from_proof(); let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); // Verify all fields assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); - assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -300,27 +298,26 @@ public fun test_get_receiver_module_and_state() { // Test unregistered receiver - should return empty values let package_id_1 = get_package_id_from_proof(); - let (module_name, proof_typename_str, expected_count) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(module_name == string::utf8(b"")); assert!(proof_typename_str == ascii::string(b"")); - assert!(expected_count == 0); // Register a receiver register_test_receiver(&mut ref, TestReceiverProof {}, ctx); // Test registered receiver - should return actual values - let (module_name, proof_typename_str, expected_count) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(module_name == string::utf8(b"receiver_registry_tests")); + // The proof typename string should contain the test receiver proof type assert!( proof_typename_str == type_name::into_string(type_name::with_defining_ids()), ); - assert!(expected_count == 0); cleanup_test(scenario, ref, owner_cap); } @@ -338,10 +335,10 @@ public fun test_register_receiver_with_zero_state_id() { // Verify the receiver is registered let package_id_1 = get_package_id_from_proof(); let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (_, _, _) = receiver_registry::get_receiver_config_fields(config); + let (_, _) = receiver_registry::get_receiver_config_fields(config); // Verify get_receiver_info returns correct values - let (module_name, proof_typename_str, _) = receiver_registry::get_receiver_info( + let (module_name, proof_typename_str) = receiver_registry::get_receiver_info( &ref, package_id_1, ); @@ -370,16 +367,15 @@ public fun test_complete_receiver_lifecycle() { // 3. Verify config is correct let config = receiver_registry::get_receiver_config(&ref, package_id_1); - let (module_name, proof_typename, expected_count) = receiver_registry::get_receiver_config_fields(config); + let (module_name, proof_typename) = receiver_registry::get_receiver_config_fields(config); assert!(module_name == string::utf8(b"receiver_registry_tests")); assert!( proof_typename == type_name::into_string(type_name::with_defining_ids()), ); - assert!(expected_count == 0); // 4. Verify module and proof typename lookup - let (lookup_module, lookup_proof_typename_str, _) = receiver_registry::get_receiver_info( + let (lookup_module, lookup_proof_typename_str) = receiver_registry::get_receiver_info( &ref, package_id_1, ); @@ -391,13 +387,12 @@ public fun test_complete_receiver_lifecycle() { assert!(!receiver_registry::is_registered_receiver(&ref, package_id_1)); // 6. Verify lookup returns empty values after unregistration - let (empty_module, empty_proof_typename_str, empty_count) = receiver_registry::get_receiver_info( + let (empty_module, empty_proof_typename_str) = receiver_registry::get_receiver_info( &ref, package_id_1, ); assert!(empty_module == string::utf8(b"")); assert!(empty_proof_typename_str == ascii::string(b"")); - assert!(empty_count == 0); cleanup_test(scenario, ref, owner_cap); } diff --git a/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move b/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move index 9ea6ab139..e0ed0322e 100644 --- a/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move +++ b/contracts/ccip/ccip_dummy_receiver/sources/ccip_dummy_receiver.move @@ -87,8 +87,7 @@ fun init(otw: DUMMY_RECEIVER, ctx: &mut TxContext) { public fun register_receiver(owner_cap: &OwnerCap, ref: &mut CCIPObjectRef) { let publisher: &Publisher = df::borrow(&owner_cap.id, PublisherKey {}); let publisher_wrapper = publisher_wrapper::create(publisher, DummyReceiverProof {}); - // 2 extra object IDs: &Clock and &mut CCIPReceiverState - receiver_registry::register_receiver(ref, publisher_wrapper, DummyReceiverProof {}, 2); + receiver_registry::register_receiver(ref, publisher_wrapper, DummyReceiverProof {}); } public fun get_counter(state: &CCIPReceiverState): u64 { diff --git a/relayer/chainwriter/ptb/offramp/execute.go b/relayer/chainwriter/ptb/offramp/execute.go index 1b4b354a4..996be58b2 100644 --- a/relayer/chainwriter/ptb/offramp/execute.go +++ b/relayer/chainwriter/ptb/offramp/execute.go @@ -8,6 +8,7 @@ import ( "context" "encoding/hex" "fmt" + "runtime" "strings" "github.com/block-vision/sui-go-sdk/models" @@ -66,6 +67,18 @@ func BuildOffRampExecutePTB( signerAddress string, addressMappings OffRampAddressMappings, ) (err error) { + defer func() { + if r := recover(); r != nil { + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + lggr.Errorw("panic recovered in BuildOffRampExecutePTB", + "panic", fmt.Sprintf("%v", r), + "stack", string(buf[:n]), + ) + err = fmt.Errorf("BuildOffRampExecutePTB panicked: %v", r) + } + }() + sdkClient := ptbClient.GetClient() offrampArgs, err := DecodeOffRampExecCallArgs(args.Args) if err != nil { @@ -302,8 +315,12 @@ func AppendPTBCommandForTokenPool( return nil, fmt.Errorf("missing function signature for token pool function not found in module (%s)", OfframpTokenPoolFunctionName) } - // Figure out the parameter types from the normalized module of the token pool - paramTypes, err := DecodeParameters(lggr, functionSignature.(map[string]any), "parameters") + funcSigMap, ok := functionSignature.(map[string]any) + if !ok { + return nil, fmt.Errorf("token pool function signature is %T, expected map[string]any", functionSignature) + } + + paramTypes, err := DecodeParameters(lggr, funcSigMap, "parameters") if err != nil { return nil, fmt.Errorf("failed to decode parameters for token pool function: %w", err) } @@ -483,15 +500,19 @@ func AppendPTBCommandForReceiver( return nil, fmt.Errorf("missing function signature for receiver function not found in module (%s)", functionName) } - // Figure out the parameter types from the normalized module of the token pool - paramTypes, err = DecodeParameters(lggr, functionSignature.(map[string]any), "parameters") + funcSigMap, ok := functionSignature.(map[string]any) + if !ok { + return nil, fmt.Errorf("receiver function signature is %T, expected map[string]any", functionSignature) + } + + paramTypes, err = DecodeParameters(lggr, funcSigMap, "parameters") if err != nil { - return nil, fmt.Errorf("failed to decode parameters for token pool function: %w", err) + return nil, fmt.Errorf("failed to decode parameters for receiver function: %w", err) } if err := ValidateReceiverCallbackSignature( lggr, - functionSignature.(map[string]any), + funcSigMap, paramTypes, addressMappings.CcipPackageId, addressMappings.OffRampPackageId, diff --git a/relayer/chainwriter/ptb/offramp/helpers.go b/relayer/chainwriter/ptb/offramp/helpers.go index 87e50bf16..96acfcb60 100644 --- a/relayer/chainwriter/ptb/offramp/helpers.go +++ b/relayer/chainwriter/ptb/offramp/helpers.go @@ -145,8 +145,7 @@ type SuiArgumentMetadata struct { Type string `json:"type"` } -func decodeParam(lggr logger.Logger, param any, reference string) SuiArgumentMetadata { - // Handle primitive types (strings like "U64", "Bool", etc.) +func decodeParam(lggr logger.Logger, param any, reference string) (SuiArgumentMetadata, error) { if str, ok := param.(string); ok { return SuiArgumentMetadata{ Address: "", @@ -155,51 +154,138 @@ func decodeParam(lggr logger.Logger, param any, reference string) SuiArgumentMet Reference: reference, TypeArguments: []TypeParameter{}, Type: ParseParamType(lggr, str), - } + }, nil + } + + m, ok := param.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf("unsupported parameter shape: expected string or map, got %T", param) } - // Handle complex types (maps) - m := param.(map[string]any) for k, v := range m { switch k { + case "TypeParameter": + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported TypeParameter in normalized module ABI (value: %v); "+ + "generic type parameters cannot be resolved by the relayer", v) case "Struct": - // Direct struct - s := v.(map[string]any) - typeArguments := []TypeParameter{} - for _, ta := range s["typeArguments"].([]any) { - typeArgument := ta.(map[string]any) - typeArguments = append(typeArguments, TypeParameter{TypeParameter: typeArgument["TypeParameter"].(float64)}) - } - return SuiArgumentMetadata{ - Address: s["address"].(string), - Module: s["module"].(string), - Name: s["name"].(string), - Reference: reference, - TypeArguments: typeArguments, - Type: ParseParamType(lggr, v), - } + return decodeStructParam(lggr, v, reference) case "Reference", "MutableReference", "Vector": - // Reference and MutableReference are the same thing - // We need to unwrap the struct return decodeParam(lggr, v, k) default: - inner := v.(map[string]any)["Struct"].(map[string]any) - typeArguments := []TypeParameter{} - for _, ta := range inner["typeArguments"].([]any) { - typeArgument := ta.(map[string]any) - typeArguments = append(typeArguments, TypeParameter{TypeParameter: typeArgument["TypeParameter"].(float64)}) + vMap, ok := v.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: expected map value, got %T", k, v) + } + innerRaw, exists := vMap["Struct"] + if !exists { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: missing nested Struct field", k) + } + inner, ok := innerRaw.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: Struct field is %T, expected map", k, innerRaw) + } + typeArguments, err := extractTypeArguments(inner) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("parameter wrapper %q: %w", k, err) + } + addr, module, name, err := extractStructIdentity(inner) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("parameter wrapper %q: %w", k, err) } return SuiArgumentMetadata{ - Address: inner["address"].(string), - Module: inner["module"].(string), - Name: inner["name"].(string), + Address: addr, + Module: module, + Name: name, Reference: k, TypeArguments: typeArguments, Type: ParseParamType(lggr, v), - } + }, nil } } - return SuiArgumentMetadata{} + return SuiArgumentMetadata{}, nil +} + +func decodeStructParam(lggr logger.Logger, v any, reference string) (SuiArgumentMetadata, error) { + s, ok := v.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf("Struct value is %T, expected map", v) + } + typeArguments, err := extractTypeArguments(s) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("Struct: %w", err) + } + addr, module, name, err := extractStructIdentity(s) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("Struct: %w", err) + } + return SuiArgumentMetadata{ + Address: addr, + Module: module, + Name: name, + Reference: reference, + TypeArguments: typeArguments, + Type: ParseParamType(lggr, v), + }, nil +} + +func extractTypeArguments(s map[string]any) ([]TypeParameter, error) { + taRaw, exists := s["typeArguments"] + if !exists { + return []TypeParameter{}, nil + } + taSlice, ok := taRaw.([]any) + if !ok { + return nil, fmt.Errorf("typeArguments is %T, expected array", taRaw) + } + typeArguments := make([]TypeParameter, 0, len(taSlice)) + for i, ta := range taSlice { + taMap, ok := ta.(map[string]any) + if !ok { + return nil, fmt.Errorf("typeArguments[%d] is %T, expected map", i, ta) + } + tpRaw, exists := taMap["TypeParameter"] + if !exists { + return nil, fmt.Errorf("typeArguments[%d] missing TypeParameter field", i) + } + tpFloat, ok := tpRaw.(float64) + if !ok { + return nil, fmt.Errorf("typeArguments[%d].TypeParameter is %T, expected float64", i, tpRaw) + } + typeArguments = append(typeArguments, TypeParameter{TypeParameter: tpFloat}) + } + return typeArguments, nil +} + +func extractStructIdentity(s map[string]any) (addr string, module string, name string, err error) { + addrRaw, ok := s["address"] + if !ok { + return "", "", "", fmt.Errorf("missing 'address' field in struct") + } + addr, ok = addrRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'address' field is %T, expected string", addrRaw) + } + modRaw, ok := s["module"] + if !ok { + return "", "", "", fmt.Errorf("missing 'module' field in struct") + } + module, ok = modRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'module' field is %T, expected string", modRaw) + } + nameRaw, ok := s["name"] + if !ok { + return "", "", "", fmt.Errorf("missing 'name' field in struct") + } + name, ok = nameRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'name' field is %T, expected string", nameRaw) + } + return addr, module, name, nil } func ParseParamType(lggr logger.Logger, param interface{}) string { @@ -276,7 +362,11 @@ func DecodeParameters(lggr logger.Logger, function map[string]any, key string) ( defaultReference := "Reference" decodedParameters := make([]SuiArgumentMetadata, len(parameters)) for i, parameter := range parameters { - decodedParameters[i] = decodeParam(lggr, parameter, defaultReference) + decoded, err := decodeParam(lggr, parameter, defaultReference) + if err != nil { + return nil, fmt.Errorf("failed to decode parameter %d: %w", i, err) + } + decodedParameters[i] = decoded } lggr.Debugw("decoded parameters", "decodedParameters", decodedParameters) diff --git a/relayer/chainwriter/ptb/offramp/helpers_test.go b/relayer/chainwriter/ptb/offramp/helpers_test.go new file mode 100644 index 000000000..3a0b1d0d8 --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/helpers_test.go @@ -0,0 +1,369 @@ +package offramp + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecodeParam_PrimitiveString(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + tests := []struct { + input string + wantName string + wantType string + }{ + {"U8", "U8", "u8"}, + {"U64", "U64", "u64"}, + {"Bool", "Bool", "bool"}, + {"Address", "Address", "object_id"}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + meta, err := decodeParam(lggr, tc.input, "Reference") + require.NoError(t, err) + assert.Equal(t, tc.wantName, meta.Name) + assert.Equal(t, tc.wantType, meta.Type) + assert.Equal(t, "Reference", meta.Reference) + }) + } +} + +func TestDecodeParam_StructDirect(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "0xcccc", meta.Address) + assert.Equal(t, "state_object", meta.Module) + assert.Equal(t, "CCIPObjectRef", meta.Name) + assert.Equal(t, "Reference", meta.Reference) + assert.Empty(t, meta.TypeArguments) +} + +func TestDecodeParam_StructWithTypeArguments(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0xaabb", + "module": "publisher_wrapper", + "name": "PublisherWrapper", + "typeArguments": []any{ + map[string]any{"TypeParameter": float64(0)}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "PublisherWrapper", meta.Name) + assert.Len(t, meta.TypeArguments, 1) + assert.Equal(t, float64(0), meta.TypeArguments[0].TypeParameter) +} + +func TestDecodeParam_Reference(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "clock", + "name": "Clock", + "typeArguments": []any{}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "Clock", meta.Name) + assert.Equal(t, "Reference", meta.Reference) +} + +func TestDecodeParam_MutableReference(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0xdead", + "module": "my_receiver", + "name": "ReceiverState", + "typeArguments": []any{}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "ReceiverState", meta.Name) + assert.Equal(t, "MutableReference", meta.Reference) +} + +func TestDecodeParam_VectorPrimitive(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{"Vector": "U8"} + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "U8", meta.Name) + assert.Equal(t, "Vector", meta.Reference) + assert.Equal(t, "u8", meta.Type) +} + +func TestDecodeParam_TypeParameter_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{"TypeParameter": float64(0)} + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParam_VectorTypeParameter_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Vector": map[string]any{"TypeParameter": float64(0)}, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParam_NonMapNonString_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := decodeParam(lggr, float64(42), "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected string or map") +} + +func TestDecodeParam_StructMissingAddress_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "module": "foo", + "name": "Bar", + "typeArguments": []any{}, + }, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing 'address'") +} + +func TestDecodeParam_StructNonMapValue_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": "not-a-map", + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "Struct value is") +} + +func TestDecodeParam_DefaultBranch_MissingNestedStruct(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "SomeOtherKey": map[string]any{"NotStruct": "value"}, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing nested Struct") +} + +func TestDecodeParam_DefaultBranch_NonMapValue(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "SomeOtherKey": "not-a-map", + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected map value") +} + +func TestDecodeParam_TypeArgumentsBadShape(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0x1", + "module": "foo", + "name": "Bar", + "typeArguments": []any{"not-a-map"}, + }, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "typeArguments[0]") +} + +func TestDecodeParameters_RejectsTypeParameter(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{ + "parameters": []any{ + map[string]any{"Vector": "U8"}, + map[string]any{"TypeParameter": float64(0)}, + }, + } + _, err := DecodeParameters(lggr, funcSig, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode parameter 1") + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParameters_ValidStandardReceiver(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{ + "parameters": []any{ + map[string]any{"Vector": "U8"}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "tx_context", + "name": "TxContext", + "typeArguments": []any{}, + }, + }, + }, + }, + } + paramTypes, err := DecodeParameters(lggr, funcSig, "parameters") + require.NoError(t, err) + assert.Equal(t, []string{"vector", "&object", "&object"}, paramTypes) +} + +func TestDecodeParameters_MissingKey(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := DecodeParameters(lggr, map[string]any{}, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing or nil") +} + +func TestDecodeParameters_NotArray(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := DecodeParameters(lggr, map[string]any{"parameters": "oops"}, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "not an array") +} + +func TestExtractTypeArguments_Empty(t *testing.T) { + t.Parallel() + + s := map[string]any{"typeArguments": []any{}} + ta, err := extractTypeArguments(s) + require.NoError(t, err) + assert.Empty(t, ta) +} + +func TestExtractTypeArguments_Missing(t *testing.T) { + t.Parallel() + + s := map[string]any{} + ta, err := extractTypeArguments(s) + require.NoError(t, err) + assert.Empty(t, ta) +} + +func TestExtractTypeArguments_WrongType(t *testing.T) { + t.Parallel() + + s := map[string]any{"typeArguments": "not-an-array"} + _, err := extractTypeArguments(s) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected array") +} + +func TestExtractStructIdentity_Valid(t *testing.T) { + t.Parallel() + + s := map[string]any{ + "address": "0x1", + "module": "foo", + "name": "Bar", + } + addr, mod, name, err := extractStructIdentity(s) + require.NoError(t, err) + assert.Equal(t, "0x1", addr) + assert.Equal(t, "foo", mod) + assert.Equal(t, "Bar", name) +} + +func TestExtractStructIdentity_MissingFields(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input map[string]any + want string + }{ + {"missing address", map[string]any{"module": "a", "name": "b"}, "missing 'address'"}, + {"missing module", map[string]any{"address": "a", "name": "b"}, "missing 'module'"}, + {"missing name", map[string]any{"address": "a", "module": "b"}, "missing 'name'"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, _, _, err := extractStructIdentity(tc.input) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.want) + }) + } +} diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation.go b/relayer/chainwriter/ptb/offramp/receiver_validation.go index 8696042f6..de5771bab 100644 --- a/relayer/chainwriter/ptb/offramp/receiver_validation.go +++ b/relayer/chainwriter/ptb/offramp/receiver_validation.go @@ -43,11 +43,12 @@ func ValidateReceiverCallbackSignature( return fmt.Errorf("'parameters' field is not an array in receiver function signature") } - // Walk raw parameters, skipping TxContext (mirroring DecodeParameters), - // and inspect every extra parameter beyond the standard 3. decodedIdx := 0 for i, rawParam := range parameters { - meta := decodeParam(lggr, rawParam, "Reference") + meta, err := decodeParam(lggr, rawParam, "Reference") + if err != nil { + return fmt.Errorf("receiver callback parameter %d: %w", i, err) + } if meta.Name == "TxContext" { continue } diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation_test.go b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go index 9d39301de..bafe243b3 100644 --- a/relayer/chainwriter/ptb/offramp/receiver_validation_test.go +++ b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go @@ -309,8 +309,6 @@ func TestValidateReceiverCallbackSignature_ImmutableCcipRefAllowed(t *testing.T) t.Parallel() lggr := logger.Test(t) - // Immutable reference to a CCIP type as an extra param should be allowed - // (read-only access is not dangerous in the same way mutable access is). params := append(standardParams(testCcipPackageId), map[string]any{ "Reference": map[string]any{ @@ -329,3 +327,55 @@ func TestValidateReceiverCallbackSignature_ImmutableCcipRefAllowed(t *testing.T) err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) require.NoError(t, err, "immutable references are safe; only mutable references to protocol types are denied") } + +func TestValidateReceiverCallbackSignature_TypeParameterReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + // Reproduces the vulnerability: a malicious receiver with + // public fun ccip_receive(v: vector, ...) produces a normalized ABI + // containing {"Vector":{"TypeParameter":0}}. Previously this panicked; now + // it must return an error. + params := []any{ + map[string]any{"Vector": map[string]any{"TypeParameter": float64(0)}}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + } + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "TypeParameter shape must be rejected, not panic") + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestValidateReceiverCallbackSignature_MalformedParamReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + float64(42), + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "unknown"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "non-map/non-string param must return error, not panic") + assert.Contains(t, err.Error(), "expected string or map") +}