diff --git a/chasm/tree.go b/chasm/tree.go index 0b8207adac5..e03f8f73d61 100644 --- a/chasm/tree.go +++ b/chasm/tree.go @@ -807,14 +807,6 @@ func (n *Node) serialize() error { } } -// marshalDeterministic encodes a proto.Message to bytes using deterministic -// serialization. Deterministic encoding sorts map keys before encoding, making -// byte comparison a reliable equality check for any well-formed proto message. -// For messages without map fields this is a no-op with no performance overhead. -func marshalDeterministic(m proto.Message) ([]byte, error) { - return proto.MarshalOptions{Deterministic: true}.Marshal(m) -} - // serializeComponentNode serializes the component node. // If this method is updated to modify serialized fields beyond Data and // LastUpdateVersionedTransition, the skip-if-clean revert logic in @@ -831,11 +823,10 @@ func (n *Node) serializeComponentNode() error { var blob *commonpb.DataBlob if !field.val.IsNil() { - data, err := marshalDeterministic(field.val.Interface().(proto.Message)) - if err != nil { - return serialization.NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err) + var err error + if blob, err = encodeChasmBlob(field.val.Interface().(proto.Message)); err != nil { + return err } - blob = &commonpb.DataBlob{EncodingType: enumspb.ENCODING_TYPE_PROTO3, Data: data} } n.serializedNode.Data = blob @@ -1205,11 +1196,10 @@ func (n *Node) serializeDataNode() error { var blob *commonpb.DataBlob if protoValue != nil { - data, err := marshalDeterministic(protoValue) - if err != nil { - return serialization.NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err) + var err error + if blob, err = encodeChasmBlob(protoValue); err != nil { + return err } - blob = &commonpb.DataBlob{EncodingType: enumspb.ENCODING_TYPE_PROTO3, Data: data} } n.serializedNode.Data = blob n.updateLastUpdateVersionedTransition() @@ -3077,7 +3067,7 @@ func serializeTask( ) (*commonpb.DataBlob, error) { protoValue, ok := taskValue.Interface().(proto.Message) if ok { - return serialization.ProtoEncode(protoValue) + return encodeChasmBlob(protoValue) } taskGoType := registrableTask.goType @@ -3090,10 +3080,7 @@ func serializeTask( // Handle empty task struct. if taskGoType.NumField() == 0 { - return &commonpb.DataBlob{ - Data: nil, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, nil + return encodeChasmBlob(nil) } // TODO: consider pre-calculating the proto field num when registring the task type. @@ -3112,7 +3099,7 @@ func serializeTask( protoMessageFound = true var err error - blob, err = serialization.ProtoEncode(fieldV.Interface().(proto.Message)) + blob, err = encodeChasmBlob(fieldV.Interface().(proto.Message)) if err != nil { return nil, err } @@ -3501,3 +3488,9 @@ func makeValidationFn( return nil } } + +// encodeChasmBlob encodes CHASM data and task payloads through the env-aware +// serializer while preserving deterministic proto3 bytes for byte comparisons. +func encodeChasmBlob(m proto.Message) (*commonpb.DataBlob, error) { + return serialization.Encode(m, serialization.WithDeterministicProto3) +} diff --git a/chasm/tree_test.go b/chasm/tree_test.go index c47fbe7f8c6..9f97903fb7a 100644 --- a/chasm/tree_test.go +++ b/chasm/tree_test.go @@ -23,7 +23,6 @@ import ( "go.temporal.io/server/common/definition" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" - "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/testing/protoassert" "go.temporal.io/server/common/testing/protorequire" @@ -252,6 +251,7 @@ func (s *nodeSuite) TestSerializeNode_DataAttributes() { err := node.serialize() s.NoError(err) s.NotNil(node.serializedNode.GetData(), "child node serialized value must have data after serialize is called") + s.Equal(enumspb.ENCODING_TYPE_PROTO3, node.serializedNode.GetData().GetEncodingType()) s.Equal([]byte{0xa, 0x2, 0x32, 0x32}, node.serializedNode.GetData().GetData()) s.Equal(valueStateSynced, node.valueState) } @@ -899,7 +899,7 @@ func (s *nodeSuite) TestNodeSnapshot() { func (s *nodeSuite) TestApplyMutation() { mustEncode := func(m proto.Message) *commonpb.DataBlob { - taskBlob, err := serialization.ProtoEncode(m) + taskBlob, err := encodeChasmBlob(m) s.NoError(err) return taskBlob } @@ -1199,9 +1199,7 @@ func (s *nodeSuite) TestApplySnapshot() { // - a new node "SubComponent2" is added. now := timestamppb.Now() - updatedRootData, err := serialization.ProtoEncode(&protoMessageType{ - StartTime: now, - }) + updatedRootData, err := encodeChasmBlob(&protoMessageType{StartTime: now}) s.NoError(err) incomingSnapshot := NodesSnapshot{ Nodes: map[string]*persistencespb.ChasmNode{ @@ -2098,7 +2096,7 @@ func (s *nodeSuite) TestSerializeDeserializeTask() { payload := &commonpb.Payload{ Data: []byte("some-random-data"), } - expectedBlob, err := serialization.ProtoEncode(payload) + expectedBlob, err := encodeChasmBlob(payload) s.NoError(err) testCases := []struct { @@ -2338,8 +2336,9 @@ func (s *nodeSuite) TestCloseTransaction_InvalidateComponentTasks() { payload := &commonpb.Payload{ Data: []byte("some-random-data"), } - taskBlob, err := serialization.ProtoEncode(payload) + taskBlob, err := encodeChasmBlob(payload) s.NoError(err) + emptyTaskBlob := s.emptyDataBlob() persistenceNodes := map[string]*persistencespb.ChasmNode{ "": { @@ -2361,11 +2360,8 @@ func (s *nodeSuite) TestCloseTransaction_InvalidateComponentTasks() { TypeId: testOutboundSideEffectTaskTypeID, VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1}, VersionedTransitionOffset: 2, - Data: &commonpb.DataBlob{ - Data: nil, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, - PhysicalTaskStatus: physicalTaskStatusCreated, + Data: emptyTaskBlob, + PhysicalTaskStatus: physicalTaskStatusCreated, }, }, PureTasks: []*persistencespb.ChasmComponentAttributes_Task{ @@ -2457,7 +2453,7 @@ func (s *nodeSuite) TestCloseTransaction_PausedStateInvalidatesTasks() { payload := &commonpb.Payload{ Data: []byte("some-random-data"), } - taskBlob, err := serialization.ProtoEncode(payload) + taskBlob, err := encodeChasmBlob(payload) s.NoError(err) makeTask := func(typeID uint32, offset int64) *persistencespb.ChasmComponentAttributes_Task { @@ -3186,7 +3182,7 @@ func (s *nodeSuite) TestEachPureTask() { now := s.timeSource.Now() mustEncode := func(m proto.Message) *commonpb.DataBlob { - taskBlob, err := serialization.ProtoEncode(m) + taskBlob, err := encodeChasmBlob(m) s.NoError(err) return taskBlob } @@ -3516,6 +3512,7 @@ func (s *nodeSuite) TestExecuteSideEffectTask() { }, } + emptyTaskBlob := s.emptyDataBlob() taskInfo := &persistencespb.ChasmTaskInfo{ ComponentInitialVersionedTransition: &persistencespb.VersionedTransition{ TransitionCount: 1, @@ -3526,10 +3523,7 @@ func (s *nodeSuite) TestExecuteSideEffectTask() { Path: []string{"SubComponent1"}, TypeId: testSideEffectTaskTypeID, ArchetypeId: testComponentTypeID, - Data: &commonpb.DataBlob{ - Data: nil, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, + Data: emptyTaskBlob, } workflowKey := definition.NewWorkflowKey( primitives.NewUUID().String(), @@ -3666,6 +3660,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() { primitives.NewUUID().String(), primitives.NewUUID().String(), ) + emptyTaskBlob := s.emptyDataBlob() chasmTask := &tasks.ChasmTask{ WorkflowKey: workflowKey, VisibilityTimestamp: s.timeSource.Now(), @@ -3682,10 +3677,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() { Path: []string{"SubComponent1"}, TypeId: testDiscardableSideEffectTaskTypeID, ArchetypeId: testComponentTypeID, - Data: &commonpb.DataBlob{ - Data: nil, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, + Data: emptyTaskBlob, }, } executionKey := ExecutionKey{ @@ -3806,6 +3798,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() { } func (s *nodeSuite) TestValidateSideEffectTask() { + emptyTaskBlob := s.emptyDataBlob() taskInfo := &persistencespb.ChasmTaskInfo{ ComponentInitialVersionedTransition: &persistencespb.VersionedTransition{ TransitionCount: 1, @@ -3817,10 +3810,7 @@ func (s *nodeSuite) TestValidateSideEffectTask() { }, Path: rootPath, TypeId: testSideEffectTaskTypeID, - Data: &commonpb.DataBlob{ - Data: nil, - EncodingType: enumspb.ENCODING_TYPE_PROTO3, - }, + Data: emptyTaskBlob, } workflowKey := definition.NewWorkflowKey( primitives.NewUUID().String(), @@ -3950,3 +3940,9 @@ func (s *nodeSuite) newTestTree( } return NewTreeFromDB(serializedNodes, s.registry, s.timeSource, s.nodeBackend, s.nodePathEncoder, s.logger, s.metricsHandler) } + +func (s *nodeSuite) emptyDataBlob() *commonpb.DataBlob { + blob, err := encodeChasmBlob(nil) + s.NoError(err) + return blob +} diff --git a/common/persistence/serialization/codec.go b/common/persistence/serialization/codec.go index 4abc8236954..f8e8426946f 100644 --- a/common/persistence/serialization/codec.go +++ b/common/persistence/serialization/codec.go @@ -23,9 +23,9 @@ import ( // WARNING: This environment variable should only be used for testing; and never set it in production. const SerializerDataEncodingEnvVar = "TEMPORAL_TEST_DATA_ENCODING" -// EncodingTypeFromEnv returns an EncodingType based on the environment variable `TEMPORAL_TEST_DATA_ENCODING`. +// encodingTypeFromEnv returns an EncodingType based on the environment variable `TEMPORAL_TEST_DATA_ENCODING`. // It defaults to "ENCODING_TYPE_PROTO3" codec if the environment variable is not set. -func EncodingTypeFromEnv() enumspb.EncodingType { +func encodingTypeFromEnv() enumspb.EncodingType { codecType := os.Getenv(SerializerDataEncodingEnvVar) switch strings.ToLower(codecType) { case "", "proto3": @@ -38,12 +38,38 @@ func EncodingTypeFromEnv() enumspb.EncodingType { } } -// ProtoEncode is kept for backward compatibility. -func ProtoEncode(m proto.Message) (*commonpb.DataBlob, error) { - return encodeBlob(m, enumspb.ENCODING_TYPE_PROTO3) +type ( + encodeOptions struct { + deterministic bool + } + EncodeOption func(*encodeOptions) +) + +// WithDeterministicProto3 uses deterministic marshaling when Encode selects proto3. +// +// Deterministic encoding sorts map keys before encoding, making byte comparison +// a reliable equality check for any well-formed proto message. For messages +// without map fields this is a no-op with no performance overhead. +var WithDeterministicProto3 EncodeOption = func(opts *encodeOptions) { + opts.deterministic = true } -func encodeBlob(m proto.Message, encoding enumspb.EncodingType) (*commonpb.DataBlob, error) { +// Encode encodes the given proto message. It respects the `TEMPORAL_TEST_DATA_ENCODING` environment variable; +// otherwise, it defaults to "ENCODING_TYPE_PROTO3". +func Encode(m proto.Message, options ...EncodeOption) (*commonpb.DataBlob, error) { + return encodeBlob(m, encodingTypeFromEnv(), options...) +} + +func encodeBlob( + m proto.Message, + encoding enumspb.EncodingType, + options ...EncodeOption, +) (*commonpb.DataBlob, error) { + opts := encodeOptions{} + for _, option := range options { + option(&opts) + } + if m == nil { return &commonpb.DataBlob{ Data: nil, @@ -62,7 +88,7 @@ func encodeBlob(m proto.Message, encoding enumspb.EncodingType) (*commonpb.DataB EncodingType: enumspb.ENCODING_TYPE_JSON, }, nil case enumspb.ENCODING_TYPE_PROTO3: - data, err := proto.Marshal(m) + data, err := proto.MarshalOptions{Deterministic: opts.deterministic}.Marshal(m) if err != nil { return nil, NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err) } diff --git a/common/persistence/serialization/codec_test.go b/common/persistence/serialization/codec_test.go index 80e7a868c21..77ea8c64822 100644 --- a/common/persistence/serialization/codec_test.go +++ b/common/persistence/serialization/codec_test.go @@ -3,31 +3,41 @@ package serialization import ( "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" persistencespb "go.temporal.io/server/api/persistence/v1" ) -func TestProtoEncode(t *testing.T) { +func TestEncode(t *testing.T) { // only testing edge cases here; happy path is covered plenty already t.Run("nil message", func(t *testing.T) { - blob, err := ProtoEncode(nil) + t.Setenv(SerializerDataEncodingEnvVar, "proto3") + blob, err := Encode(nil) require.NoError(t, err) require.NotNil(t, blob) - assert.Equal(t, enumspb.ENCODING_TYPE_PROTO3, blob.EncodingType) - assert.Nil(t, blob.Data) + require.Equal(t, enumspb.ENCODING_TYPE_PROTO3, blob.EncodingType) + require.Nil(t, blob.Data) }) t.Run("nil pointer message", func(t *testing.T) { + t.Setenv(SerializerDataEncodingEnvVar, "proto3") var shardInfo *persistencespb.ShardInfo - blob, err := ProtoEncode(shardInfo) + blob, err := Encode(shardInfo) require.NoError(t, err) require.NotNil(t, blob) - assert.Equal(t, enumspb.ENCODING_TYPE_PROTO3, blob.EncodingType) - assert.Nil(t, blob.Data) + require.Equal(t, enumspb.ENCODING_TYPE_PROTO3, blob.EncodingType) + require.Nil(t, blob.Data) + }) + + t.Run("respects env encoding", func(t *testing.T) { + t.Setenv(SerializerDataEncodingEnvVar, "json") + blob, err := Encode(nil) + require.NoError(t, err) + require.NotNil(t, blob) + require.Equal(t, enumspb.ENCODING_TYPE_JSON, blob.EncodingType) + require.Nil(t, blob.Data) }) } @@ -38,8 +48,9 @@ func TestProtoDecode(t *testing.T) { var result persistencespb.ShardInfo err := Decode(nil, &result) require.Error(t, err) - assert.IsType(t, &DeserializationError{}, err) - assert.Contains(t, err.Error(), "cannot decode nil") + var deserializationErr *DeserializationError + require.ErrorAs(t, err, &deserializationErr) + require.Contains(t, err.Error(), "cannot decode nil") }) t.Run("empty data blob", func(t *testing.T) { @@ -48,8 +59,9 @@ func TestProtoDecode(t *testing.T) { var result persistencespb.ShardInfo err := Decode(blob, &result) require.Error(t, err) - assert.IsType(t, &UnknownEncodingTypeError{}, err) - assert.Contains(t, err.Error(), "unknown or unsupported encoding type Unspecified") + var unknownEncodingErr *UnknownEncodingTypeError + require.ErrorAs(t, err, &unknownEncodingErr) + require.Contains(t, err.Error(), "unknown or unsupported encoding type Unspecified") }) t.Run("nil data field", func(t *testing.T) { @@ -72,8 +84,9 @@ func TestProtoDecode(t *testing.T) { var result persistencespb.ShardInfo err := Decode(blob, &result) require.Error(t, err) - assert.IsType(t, &UnknownEncodingTypeError{}, err) - assert.Contains(t, err.Error(), "unknown or unsupported encoding type 999") + var unknownEncodingErr *UnknownEncodingTypeError + require.ErrorAs(t, err, &unknownEncodingErr) + require.Contains(t, err.Error(), "unknown or unsupported encoding type 999") }) t.Run("invalid proto data", func(t *testing.T) { @@ -85,7 +98,8 @@ func TestProtoDecode(t *testing.T) { var result persistencespb.ShardInfo err := Decode(blob, &result) require.Error(t, err) - assert.IsType(t, &DeserializationError{}, err) - assert.Contains(t, err.Error(), "error deserializing using Proto3 encoding") + var deserializationErr *DeserializationError + require.ErrorAs(t, err, &deserializationErr) + require.Contains(t, err.Error(), "error deserializing using Proto3 encoding") }) } diff --git a/common/persistence/serialization/serializer.go b/common/persistence/serialization/serializer.go index 82475c81c57..ca412b64367 100644 --- a/common/persistence/serialization/serializer.go +++ b/common/persistence/serialization/serializer.go @@ -139,7 +139,7 @@ type ( ) func NewSerializer() Serializer { - return &serializerImpl{encodingType: EncodingTypeFromEnv()} + return &serializerImpl{encodingType: encodingTypeFromEnv()} } func (t *serializerImpl) EncodingType() enumspb.EncodingType { diff --git a/common/persistence/tests/util.go b/common/persistence/tests/util.go index 2f67bf30716..1671658fd57 100644 --- a/common/persistence/tests/util.go +++ b/common/persistence/tests/util.go @@ -167,7 +167,7 @@ func RandomChasmNode() *persistencespb.ChasmNode { // Some arbitrary random data to ensure the chasm node's attributes are preserved. var blobInfo persistencespb.WorkflowExecutionInfo _ = fakedata.FakeStruct(&blobInfo) - blob, _ := serialization.ProtoEncode(&blobInfo) + blob, _ := serialization.Encode(&blobInfo) var versionedTransition persistencespb.VersionedTransition _ = fakedata.FakeStruct(&versionedTransition) diff --git a/service/history/chasm_engine_test.go b/service/history/chasm_engine_test.go index dc6c8d900e8..6dfcb6fe483 100644 --- a/service/history/chasm_engine_test.go +++ b/service/history/chasm_engine_test.go @@ -1989,7 +1989,7 @@ func (s *chasmEngineSuite) buildPersistenceMutableState( func (s *chasmEngineSuite) serializeComponentState( state proto.Message, ) *commonpb.DataBlob { - blob, err := serialization.ProtoEncode(state) + blob, err := serialization.Encode(state, serialization.WithDeterministicProto3) s.NoError(err) return blob }