diff --git a/relayer/chainwriter/ptb/offramp/execute.go b/relayer/chainwriter/ptb/offramp/execute.go index 56636229c..996be58b2 100644 --- a/relayer/chainwriter/ptb/offramp/execute.go +++ b/relayer/chainwriter/ptb/offramp/execute.go @@ -8,6 +8,7 @@ import ( "context" "encoding/hex" "fmt" + "runtime" "strings" "github.com/block-vision/sui-go-sdk/models" @@ -66,6 +67,18 @@ func BuildOffRampExecutePTB( signerAddress string, addressMappings OffRampAddressMappings, ) (err error) { + defer func() { + if r := recover(); r != nil { + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + lggr.Errorw("panic recovered in BuildOffRampExecutePTB", + "panic", fmt.Sprintf("%v", r), + "stack", string(buf[:n]), + ) + err = fmt.Errorf("BuildOffRampExecutePTB panicked: %v", r) + } + }() + sdkClient := ptbClient.GetClient() offrampArgs, err := DecodeOffRampExecCallArgs(args.Args) if err != nil { @@ -302,8 +315,12 @@ func AppendPTBCommandForTokenPool( return nil, fmt.Errorf("missing function signature for token pool function not found in module (%s)", OfframpTokenPoolFunctionName) } - // Figure out the parameter types from the normalized module of the token pool - paramTypes, err := DecodeParameters(lggr, functionSignature.(map[string]any), "parameters") + funcSigMap, ok := functionSignature.(map[string]any) + if !ok { + return nil, fmt.Errorf("token pool function signature is %T, expected map[string]any", functionSignature) + } + + paramTypes, err := DecodeParameters(lggr, funcSigMap, "parameters") if err != nil { return nil, fmt.Errorf("failed to decode parameters for token pool function: %w", err) } @@ -483,10 +500,24 @@ func AppendPTBCommandForReceiver( return nil, fmt.Errorf("missing function signature for receiver function not found in module (%s)", functionName) } - // Figure out the parameter types from the normalized module of the token pool - paramTypes, err = DecodeParameters(lggr, functionSignature.(map[string]any), "parameters") + funcSigMap, ok := functionSignature.(map[string]any) + if !ok { + return nil, fmt.Errorf("receiver function signature is %T, expected map[string]any", functionSignature) + } + + paramTypes, err = DecodeParameters(lggr, funcSigMap, "parameters") if err != nil { - return nil, fmt.Errorf("failed to decode parameters for token pool function: %w", err) + return nil, fmt.Errorf("failed to decode parameters for receiver function: %w", err) + } + + if err := ValidateReceiverCallbackSignature( + lggr, + funcSigMap, + paramTypes, + addressMappings.CcipPackageId, + addressMappings.OffRampPackageId, + ); err != nil { + return nil, fmt.Errorf("receiver callback validation failed for %s::%s: %w", moduleId, functionName, err) } lggr.Debugw("calling receiver", "paramTypes", paramTypes, "paramValues", paramValues) @@ -515,11 +546,21 @@ func AppendPTBCommandForReceiver( lggr.Error("unexpected receiverObjectIds type", "type", fmt.Sprintf("%T", receiverObjectIds)) } + if err := ValidateReceiverObjectIdCount(paramTypes, len(extraArgsValues)); err != nil { + return nil, fmt.Errorf("receiver %s::%s: %w", moduleId, functionName, err) + } + + var receiverObjectIdStrings []string for _, value := range extraArgsValues { objectId := hex.EncodeToString(value) + receiverObjectIdStrings = append(receiverObjectIdStrings, "0x"+objectId) paramValues = append(paramValues, bind.Object{Id: "0x" + objectId}) } + if err := ValidateReceiverObjectIds(receiverObjectIdStrings, addressMappings); err != nil { + return nil, fmt.Errorf("receiver %s::%s: %w", moduleId, functionName, err) + } + encodedReceiverCall, err := boundReceiverContract.EncodeCallArgsWithGenerics( functionName, typeArgsList, diff --git a/relayer/chainwriter/ptb/offramp/helpers.go b/relayer/chainwriter/ptb/offramp/helpers.go index 87e50bf16..96acfcb60 100644 --- a/relayer/chainwriter/ptb/offramp/helpers.go +++ b/relayer/chainwriter/ptb/offramp/helpers.go @@ -145,8 +145,7 @@ type SuiArgumentMetadata struct { Type string `json:"type"` } -func decodeParam(lggr logger.Logger, param any, reference string) SuiArgumentMetadata { - // Handle primitive types (strings like "U64", "Bool", etc.) +func decodeParam(lggr logger.Logger, param any, reference string) (SuiArgumentMetadata, error) { if str, ok := param.(string); ok { return SuiArgumentMetadata{ Address: "", @@ -155,51 +154,138 @@ func decodeParam(lggr logger.Logger, param any, reference string) SuiArgumentMet Reference: reference, TypeArguments: []TypeParameter{}, Type: ParseParamType(lggr, str), - } + }, nil + } + + m, ok := param.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf("unsupported parameter shape: expected string or map, got %T", param) } - // Handle complex types (maps) - m := param.(map[string]any) for k, v := range m { switch k { + case "TypeParameter": + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported TypeParameter in normalized module ABI (value: %v); "+ + "generic type parameters cannot be resolved by the relayer", v) case "Struct": - // Direct struct - s := v.(map[string]any) - typeArguments := []TypeParameter{} - for _, ta := range s["typeArguments"].([]any) { - typeArgument := ta.(map[string]any) - typeArguments = append(typeArguments, TypeParameter{TypeParameter: typeArgument["TypeParameter"].(float64)}) - } - return SuiArgumentMetadata{ - Address: s["address"].(string), - Module: s["module"].(string), - Name: s["name"].(string), - Reference: reference, - TypeArguments: typeArguments, - Type: ParseParamType(lggr, v), - } + return decodeStructParam(lggr, v, reference) case "Reference", "MutableReference", "Vector": - // Reference and MutableReference are the same thing - // We need to unwrap the struct return decodeParam(lggr, v, k) default: - inner := v.(map[string]any)["Struct"].(map[string]any) - typeArguments := []TypeParameter{} - for _, ta := range inner["typeArguments"].([]any) { - typeArgument := ta.(map[string]any) - typeArguments = append(typeArguments, TypeParameter{TypeParameter: typeArgument["TypeParameter"].(float64)}) + vMap, ok := v.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: expected map value, got %T", k, v) + } + innerRaw, exists := vMap["Struct"] + if !exists { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: missing nested Struct field", k) + } + inner, ok := innerRaw.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf( + "unsupported parameter wrapper %q: Struct field is %T, expected map", k, innerRaw) + } + typeArguments, err := extractTypeArguments(inner) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("parameter wrapper %q: %w", k, err) + } + addr, module, name, err := extractStructIdentity(inner) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("parameter wrapper %q: %w", k, err) } return SuiArgumentMetadata{ - Address: inner["address"].(string), - Module: inner["module"].(string), - Name: inner["name"].(string), + Address: addr, + Module: module, + Name: name, Reference: k, TypeArguments: typeArguments, Type: ParseParamType(lggr, v), - } + }, nil } } - return SuiArgumentMetadata{} + return SuiArgumentMetadata{}, nil +} + +func decodeStructParam(lggr logger.Logger, v any, reference string) (SuiArgumentMetadata, error) { + s, ok := v.(map[string]any) + if !ok { + return SuiArgumentMetadata{}, fmt.Errorf("Struct value is %T, expected map", v) + } + typeArguments, err := extractTypeArguments(s) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("Struct: %w", err) + } + addr, module, name, err := extractStructIdentity(s) + if err != nil { + return SuiArgumentMetadata{}, fmt.Errorf("Struct: %w", err) + } + return SuiArgumentMetadata{ + Address: addr, + Module: module, + Name: name, + Reference: reference, + TypeArguments: typeArguments, + Type: ParseParamType(lggr, v), + }, nil +} + +func extractTypeArguments(s map[string]any) ([]TypeParameter, error) { + taRaw, exists := s["typeArguments"] + if !exists { + return []TypeParameter{}, nil + } + taSlice, ok := taRaw.([]any) + if !ok { + return nil, fmt.Errorf("typeArguments is %T, expected array", taRaw) + } + typeArguments := make([]TypeParameter, 0, len(taSlice)) + for i, ta := range taSlice { + taMap, ok := ta.(map[string]any) + if !ok { + return nil, fmt.Errorf("typeArguments[%d] is %T, expected map", i, ta) + } + tpRaw, exists := taMap["TypeParameter"] + if !exists { + return nil, fmt.Errorf("typeArguments[%d] missing TypeParameter field", i) + } + tpFloat, ok := tpRaw.(float64) + if !ok { + return nil, fmt.Errorf("typeArguments[%d].TypeParameter is %T, expected float64", i, tpRaw) + } + typeArguments = append(typeArguments, TypeParameter{TypeParameter: tpFloat}) + } + return typeArguments, nil +} + +func extractStructIdentity(s map[string]any) (addr string, module string, name string, err error) { + addrRaw, ok := s["address"] + if !ok { + return "", "", "", fmt.Errorf("missing 'address' field in struct") + } + addr, ok = addrRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'address' field is %T, expected string", addrRaw) + } + modRaw, ok := s["module"] + if !ok { + return "", "", "", fmt.Errorf("missing 'module' field in struct") + } + module, ok = modRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'module' field is %T, expected string", modRaw) + } + nameRaw, ok := s["name"] + if !ok { + return "", "", "", fmt.Errorf("missing 'name' field in struct") + } + name, ok = nameRaw.(string) + if !ok { + return "", "", "", fmt.Errorf("'name' field is %T, expected string", nameRaw) + } + return addr, module, name, nil } func ParseParamType(lggr logger.Logger, param interface{}) string { @@ -276,7 +362,11 @@ func DecodeParameters(lggr logger.Logger, function map[string]any, key string) ( defaultReference := "Reference" decodedParameters := make([]SuiArgumentMetadata, len(parameters)) for i, parameter := range parameters { - decodedParameters[i] = decodeParam(lggr, parameter, defaultReference) + decoded, err := decodeParam(lggr, parameter, defaultReference) + if err != nil { + return nil, fmt.Errorf("failed to decode parameter %d: %w", i, err) + } + decodedParameters[i] = decoded } lggr.Debugw("decoded parameters", "decodedParameters", decodedParameters) diff --git a/relayer/chainwriter/ptb/offramp/helpers_test.go b/relayer/chainwriter/ptb/offramp/helpers_test.go new file mode 100644 index 000000000..3a0b1d0d8 --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/helpers_test.go @@ -0,0 +1,369 @@ +package offramp + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecodeParam_PrimitiveString(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + tests := []struct { + input string + wantName string + wantType string + }{ + {"U8", "U8", "u8"}, + {"U64", "U64", "u64"}, + {"Bool", "Bool", "bool"}, + {"Address", "Address", "object_id"}, + } + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + meta, err := decodeParam(lggr, tc.input, "Reference") + require.NoError(t, err) + assert.Equal(t, tc.wantName, meta.Name) + assert.Equal(t, tc.wantType, meta.Type) + assert.Equal(t, "Reference", meta.Reference) + }) + } +} + +func TestDecodeParam_StructDirect(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "0xcccc", meta.Address) + assert.Equal(t, "state_object", meta.Module) + assert.Equal(t, "CCIPObjectRef", meta.Name) + assert.Equal(t, "Reference", meta.Reference) + assert.Empty(t, meta.TypeArguments) +} + +func TestDecodeParam_StructWithTypeArguments(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0xaabb", + "module": "publisher_wrapper", + "name": "PublisherWrapper", + "typeArguments": []any{ + map[string]any{"TypeParameter": float64(0)}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "PublisherWrapper", meta.Name) + assert.Len(t, meta.TypeArguments, 1) + assert.Equal(t, float64(0), meta.TypeArguments[0].TypeParameter) +} + +func TestDecodeParam_Reference(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "clock", + "name": "Clock", + "typeArguments": []any{}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "Clock", meta.Name) + assert.Equal(t, "Reference", meta.Reference) +} + +func TestDecodeParam_MutableReference(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0xdead", + "module": "my_receiver", + "name": "ReceiverState", + "typeArguments": []any{}, + }, + }, + } + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "ReceiverState", meta.Name) + assert.Equal(t, "MutableReference", meta.Reference) +} + +func TestDecodeParam_VectorPrimitive(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{"Vector": "U8"} + meta, err := decodeParam(lggr, param, "Reference") + require.NoError(t, err) + assert.Equal(t, "U8", meta.Name) + assert.Equal(t, "Vector", meta.Reference) + assert.Equal(t, "u8", meta.Type) +} + +func TestDecodeParam_TypeParameter_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{"TypeParameter": float64(0)} + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParam_VectorTypeParameter_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Vector": map[string]any{"TypeParameter": float64(0)}, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParam_NonMapNonString_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := decodeParam(lggr, float64(42), "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected string or map") +} + +func TestDecodeParam_StructMissingAddress_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "module": "foo", + "name": "Bar", + "typeArguments": []any{}, + }, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing 'address'") +} + +func TestDecodeParam_StructNonMapValue_ReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": "not-a-map", + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "Struct value is") +} + +func TestDecodeParam_DefaultBranch_MissingNestedStruct(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "SomeOtherKey": map[string]any{"NotStruct": "value"}, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing nested Struct") +} + +func TestDecodeParam_DefaultBranch_NonMapValue(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "SomeOtherKey": "not-a-map", + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "expected map value") +} + +func TestDecodeParam_TypeArgumentsBadShape(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + param := map[string]any{ + "Struct": map[string]any{ + "address": "0x1", + "module": "foo", + "name": "Bar", + "typeArguments": []any{"not-a-map"}, + }, + } + _, err := decodeParam(lggr, param, "Reference") + require.Error(t, err) + assert.Contains(t, err.Error(), "typeArguments[0]") +} + +func TestDecodeParameters_RejectsTypeParameter(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{ + "parameters": []any{ + map[string]any{"Vector": "U8"}, + map[string]any{"TypeParameter": float64(0)}, + }, + } + _, err := DecodeParameters(lggr, funcSig, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode parameter 1") + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestDecodeParameters_ValidStandardReceiver(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{ + "parameters": []any{ + map[string]any{"Vector": "U8"}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": "0xcccc", + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "tx_context", + "name": "TxContext", + "typeArguments": []any{}, + }, + }, + }, + }, + } + paramTypes, err := DecodeParameters(lggr, funcSig, "parameters") + require.NoError(t, err) + assert.Equal(t, []string{"vector", "&object", "&object"}, paramTypes) +} + +func TestDecodeParameters_MissingKey(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := DecodeParameters(lggr, map[string]any{}, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing or nil") +} + +func TestDecodeParameters_NotArray(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + _, err := DecodeParameters(lggr, map[string]any{"parameters": "oops"}, "parameters") + require.Error(t, err) + assert.Contains(t, err.Error(), "not an array") +} + +func TestExtractTypeArguments_Empty(t *testing.T) { + t.Parallel() + + s := map[string]any{"typeArguments": []any{}} + ta, err := extractTypeArguments(s) + require.NoError(t, err) + assert.Empty(t, ta) +} + +func TestExtractTypeArguments_Missing(t *testing.T) { + t.Parallel() + + s := map[string]any{} + ta, err := extractTypeArguments(s) + require.NoError(t, err) + assert.Empty(t, ta) +} + +func TestExtractTypeArguments_WrongType(t *testing.T) { + t.Parallel() + + s := map[string]any{"typeArguments": "not-an-array"} + _, err := extractTypeArguments(s) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected array") +} + +func TestExtractStructIdentity_Valid(t *testing.T) { + t.Parallel() + + s := map[string]any{ + "address": "0x1", + "module": "foo", + "name": "Bar", + } + addr, mod, name, err := extractStructIdentity(s) + require.NoError(t, err) + assert.Equal(t, "0x1", addr) + assert.Equal(t, "foo", mod) + assert.Equal(t, "Bar", name) +} + +func TestExtractStructIdentity_MissingFields(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input map[string]any + want string + }{ + {"missing address", map[string]any{"module": "a", "name": "b"}, "missing 'address'"}, + {"missing module", map[string]any{"address": "a", "name": "b"}, "missing 'module'"}, + {"missing name", map[string]any{"address": "a", "module": "b"}, "missing 'name'"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, _, _, err := extractStructIdentity(tc.input) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.want) + }) + } +} diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation.go b/relayer/chainwriter/ptb/offramp/receiver_validation.go new file mode 100644 index 000000000..de5771bab --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/receiver_validation.go @@ -0,0 +1,147 @@ +package offramp + +import ( + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// receiverStandardParamCount is the number of standard parameters that every +// ccip_receive callback must begin with: +// +// [0] expected_message_id: vector +// [1] ref: &CCIPObjectRef +// [2] message: Any2SuiMessage +const receiverStandardParamCount = 3 + +// ValidateReceiverCallbackSignature validates that a receiver's ccip_receive +// callback does not declare extra parameters whose types belong to known CCIP +// protocol packages. This prevents a malicious receiver from tricking the +// relayer into injecting protocol-owned objects (e.g. OnRampState) as mutable +// PTB inputs in the transmitter-signed transaction. +func ValidateReceiverCallbackSignature( + lggr logger.Logger, + functionSig map[string]any, + decodedParamTypes []string, + ccipPackageId string, + offRampPackageId string, +) error { + if len(decodedParamTypes) < receiverStandardParamCount { + return fmt.Errorf( + "receiver callback has %d parameters, expected at least %d standard parameters "+ + "(expected_message_id, &CCIPObjectRef, Any2SuiMessage)", + len(decodedParamTypes), receiverStandardParamCount, + ) + } + + parametersRaw, ok := functionSig["parameters"] + if !ok { + return fmt.Errorf("missing 'parameters' field in receiver function signature") + } + parameters, ok := parametersRaw.([]any) + if !ok { + return fmt.Errorf("'parameters' field is not an array in receiver function signature") + } + + decodedIdx := 0 + for i, rawParam := range parameters { + meta, err := decodeParam(lggr, rawParam, "Reference") + if err != nil { + return fmt.Errorf("receiver callback parameter %d: %w", i, err) + } + if meta.Name == "TxContext" { + continue + } + + if decodedIdx >= receiverStandardParamCount { + if meta.Reference == "MutableReference" { + if isDeniedProtocolPackage(meta.Address, ccipPackageId, offRampPackageId) { + return fmt.Errorf( + "receiver callback parameter %d declares mutable reference to CCIP protocol type %s::%s::%s; "+ + "receiver callbacks must not accept mutable references to CCIP protocol objects", + i, meta.Address, meta.Module, meta.Name, + ) + } + if isDeniedProtocolModule(meta.Module, meta.Name) { + return fmt.Errorf( + "receiver callback parameter %d references denied protocol type %s::%s; "+ + "receiver callbacks must not accept references to CCIP internal objects", + i, meta.Module, meta.Name, + ) + } + } + } + + decodedIdx++ + } + + return nil +} + +// ValidateReceiverObjectIdCount ensures the number of receiverObjectIds matches +// the number of extra parameters declared by the callback beyond the standard 3. +// A mismatch indicates the callback ABI and the message's extra args are +// inconsistent, which is a precondition for the object injection attack. +func ValidateReceiverObjectIdCount(decodedParamTypes []string, receiverObjectIdCount int) error { + expectedExtraParams := len(decodedParamTypes) - receiverStandardParamCount + if expectedExtraParams < 0 { + expectedExtraParams = 0 + } + if receiverObjectIdCount != expectedExtraParams { + return fmt.Errorf( + "receiver callback declares %d extra object parameters but receiverObjectIds contains %d entries; counts must match", + expectedExtraParams, receiverObjectIdCount, + ) + } + return nil +} + +// ValidateReceiverObjectIds checks that none of the supplied receiver object +// IDs reference known CCIP protocol objects. Accepting protocol objects as +// receiver callback arguments would let a malicious receiver modify protocol +// state via the transmitter-signed PTB. +func ValidateReceiverObjectIds(objectIds []string, addressMappings *OffRampAddressMappings) error { + denied := map[string]string{ + addressMappings.CcipObjectRef: "CCIPObjectRef", + addressMappings.OffRampState: "OffRampState", + } + if addressMappings.CcipOwnerCap != "" { + denied[addressMappings.CcipOwnerCap] = "CcipOwnerCap" + } + + for i, objectId := range objectIds { + if name, found := denied[objectId]; found { + return fmt.Errorf( + "receiverObjectIds[%d] (%s) references protocol object %s; "+ + "receiver callbacks must not be passed CCIP protocol objects", + i, objectId, name, + ) + } + } + return nil +} + +func isDeniedProtocolPackage(addr, ccipPackageId, offRampPackageId string) bool { + return addr != "" && (addr == ccipPackageId || addr == offRampPackageId) +} + +// isDeniedProtocolModule provides a defense-in-depth check against known CCIP +// protocol module+type combinations. This catches cases where the attacker's +// package references protocol types whose package ID isn't in addressMappings +// (e.g. the onramp package). +func isDeniedProtocolModule(module, name string) bool { + denied := map[string]map[string]bool{ + "onramp": {"OnRampState": true}, + "offramp": {"OffRampState": true}, + "fee_quoter": {"FeeQuoterState": true}, + "token_admin_registry": {"TokenAdminRegistryState": true}, + "receiver_registry": {"ReceiverRegistry": true}, + "nonce_manager": {"NonceManagerState": true}, + "state_object": {"CCIPObjectRef": true}, + "offramp_state_helper": {"ReceiverParams": true}, + } + if names, ok := denied[module]; ok { + return names[name] + } + return false +} diff --git a/relayer/chainwriter/ptb/offramp/receiver_validation_test.go b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go new file mode 100644 index 000000000..bafe243b3 --- /dev/null +++ b/relayer/chainwriter/ptb/offramp/receiver_validation_test.go @@ -0,0 +1,381 @@ +package offramp + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testCcipPackageId = "0x00000000000000000000000000000000000000000000000000000000ccipccip" + testOffRampPackageId = "0x000000000000000000000000000000000000000000000000000000000ff2a3f0" + testCcipObjectRef = "0x000000000000000000000000000000000000000000000000000000000bj3c7r3" + testOffRampState = "0x000000000000000000000000000000000000000000000000000000000ff2a357" + testCcipOwnerCap = "0x0000000000000000000000000000000000000000000000000000000000ca9ca9" +) + +func testAddressMappings() *OffRampAddressMappings { + return &OffRampAddressMappings{ + CcipPackageId: testCcipPackageId, + CcipObjectRef: testCcipObjectRef, + CcipOwnerCap: testCcipOwnerCap, + OffRampPackageId: testOffRampPackageId, + OffRampState: testOffRampState, + } +} + +func standardParams(ccipPkgId string) []any { + return []any{ + map[string]any{"Vector": "U8"}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": ccipPkgId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": ccipPkgId, + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + } +} + +func TestValidateReceiverCallbackSignature_StandardParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := standardParams(testCcipPackageId) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err) +} + +func TestValidateReceiverCallbackSignature_LegitExtraParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "clock", + "name": "Clock", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0xdeadbeef", + "module": "my_receiver", + "name": "ReceiverState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&object", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "legitimate extra params (Clock + receiver's own state) should pass") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableCcipProtocolType(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "fee_quoter", + "name": "FeeQuoterState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutable reference to CCIP protocol type") + assert.Contains(t, err.Error(), "FeeQuoterState") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableOnRampState(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + onrampPackageId := "0x0000000000000000000000000000000000000000000000000000000000012345" + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": onrampPackageId, + "module": "onramp", + "name": "OnRampState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "OnRampState should be caught by module name denylist even when package ID is unknown") + assert.Contains(t, err.Error(), "denied protocol type") + assert.Contains(t, err.Error(), "OnRampState") +} + +func TestValidateReceiverCallbackSignature_RejectsMutableOffRampPackageType(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": testOffRampPackageId, + "module": "offramp", + "name": "OffRampState", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&mut object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "CCIP protocol type") +} + +func TestValidateReceiverCallbackSignature_TooFewParams(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + funcSig := map[string]any{"parameters": []any{map[string]any{"Vector": "U8"}}} + decodedTypes := []string{"vector"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected at least 3 standard parameters") +} + +func TestValidateReceiverCallbackSignature_TxContextSkipped(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "MutableReference": map[string]any{ + "Struct": map[string]any{ + "address": "0x2", + "module": "tx_context", + "name": "TxContext", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "TxContext parameter should be skipped") +} + +func TestValidateReceiverObjectIdCount_Matching(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id", "&object", "&mut object"} + err := ValidateReceiverObjectIdCount(decodedTypes, 2) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIdCount_ExactlyStandard(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id"} + err := ValidateReceiverObjectIdCount(decodedTypes, 0) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIdCount_Mismatch_TooMany(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id"} + err := ValidateReceiverObjectIdCount(decodedTypes, 2) + require.Error(t, err) + assert.Contains(t, err.Error(), "declares 0 extra object parameters but receiverObjectIds contains 2") +} + +func TestValidateReceiverObjectIdCount_Mismatch_TooFew(t *testing.T) { + t.Parallel() + + decodedTypes := []string{"vector", "&object", "object_id", "&mut object", "&mut object"} + err := ValidateReceiverObjectIdCount(decodedTypes, 1) + require.Error(t, err) + assert.Contains(t, err.Error(), "declares 2 extra object parameters but receiverObjectIds contains 1") +} + +func TestValidateReceiverObjectIds_Safe(t *testing.T) { + t.Parallel() + + objectIds := []string{ + "0x0000000000000000000000000000000000000000000000000000000000aaaaaa", + "0x0000000000000000000000000000000000000000000000000000000000bbbbbb", + } + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.NoError(t, err) +} + +func TestValidateReceiverObjectIds_RejectsCcipObjectRef(t *testing.T) { + t.Parallel() + + objectIds := []string{testCcipObjectRef} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "CCIPObjectRef") +} + +func TestValidateReceiverObjectIds_RejectsOffRampState(t *testing.T) { + t.Parallel() + + objectIds := []string{testOffRampState} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "OffRampState") +} + +func TestValidateReceiverObjectIds_RejectsCcipOwnerCap(t *testing.T) { + t.Parallel() + + objectIds := []string{testCcipOwnerCap} + err := ValidateReceiverObjectIds(objectIds, testAddressMappings()) + require.Error(t, err) + assert.Contains(t, err.Error(), "CcipOwnerCap") +} + +func TestValidateReceiverObjectIds_EmptyList(t *testing.T) { + t.Parallel() + + err := ValidateReceiverObjectIds(nil, testAddressMappings()) + require.NoError(t, err) +} + +func TestIsDeniedProtocolPackage(t *testing.T) { + t.Parallel() + + assert.True(t, isDeniedProtocolPackage(testCcipPackageId, testCcipPackageId, testOffRampPackageId)) + assert.True(t, isDeniedProtocolPackage(testOffRampPackageId, testCcipPackageId, testOffRampPackageId)) + assert.False(t, isDeniedProtocolPackage("0xdeadbeef", testCcipPackageId, testOffRampPackageId)) + assert.False(t, isDeniedProtocolPackage("", testCcipPackageId, testOffRampPackageId)) +} + +func TestIsDeniedProtocolModule(t *testing.T) { + t.Parallel() + + assert.True(t, isDeniedProtocolModule("onramp", "OnRampState")) + assert.True(t, isDeniedProtocolModule("offramp", "OffRampState")) + assert.True(t, isDeniedProtocolModule("fee_quoter", "FeeQuoterState")) + assert.True(t, isDeniedProtocolModule("state_object", "CCIPObjectRef")) + assert.False(t, isDeniedProtocolModule("my_receiver", "ReceiverState")) + assert.False(t, isDeniedProtocolModule("onramp", "SomeOtherType")) + assert.False(t, isDeniedProtocolModule("clock", "Clock")) +} + +func TestValidateReceiverCallbackSignature_ImmutableCcipRefAllowed(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "&object"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.NoError(t, err, "immutable references are safe; only mutable references to protocol types are denied") +} + +func TestValidateReceiverCallbackSignature_TypeParameterReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + // Reproduces the vulnerability: a malicious receiver with + // public fun ccip_receive(v: vector, ...) produces a normalized ABI + // containing {"Vector":{"TypeParameter":0}}. Previously this panicked; now + // it must return an error. + params := []any{ + map[string]any{"Vector": map[string]any{"TypeParameter": float64(0)}}, + map[string]any{ + "Reference": map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "state_object", + "name": "CCIPObjectRef", + "typeArguments": []any{}, + }, + }, + }, + map[string]any{ + "Struct": map[string]any{ + "address": testCcipPackageId, + "module": "client", + "name": "Any2SuiMessage", + "typeArguments": []any{}, + }, + }, + } + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "TypeParameter shape must be rejected, not panic") + assert.Contains(t, err.Error(), "unsupported TypeParameter") +} + +func TestValidateReceiverCallbackSignature_MalformedParamReturnsError(t *testing.T) { + t.Parallel() + lggr := logger.Test(t) + + params := append(standardParams(testCcipPackageId), + float64(42), + ) + funcSig := map[string]any{"parameters": params} + decodedTypes := []string{"vector", "&object", "object_id", "unknown"} + + err := ValidateReceiverCallbackSignature(lggr, funcSig, decodedTypes, testCcipPackageId, testOffRampPackageId) + require.Error(t, err, "non-map/non-string param must return error, not panic") + assert.Contains(t, err.Error(), "expected string or map") +}