Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions chasm/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,14 +807,6 @@ func (n *Node) serialize() error {
}
}

// marshalDeterministic encodes a proto.Message to bytes using deterministic
// serialization. Deterministic encoding sorts map keys before encoding, making
// byte comparison a reliable equality check for any well-formed proto message.
// For messages without map fields this is a no-op with no performance overhead.
func marshalDeterministic(m proto.Message) ([]byte, error) {
return proto.MarshalOptions{Deterministic: true}.Marshal(m)
}

// serializeComponentNode serializes the component node.
// If this method is updated to modify serialized fields beyond Data and
// LastUpdateVersionedTransition, the skip-if-clean revert logic in
Expand All @@ -831,11 +823,10 @@ func (n *Node) serializeComponentNode() error {

var blob *commonpb.DataBlob
if !field.val.IsNil() {
data, err := marshalDeterministic(field.val.Interface().(proto.Message))
if err != nil {
return serialization.NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err)
var err error
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: what is preferred, declaring var err, or using :=?

Copy link
Copy Markdown
Contributor Author

@stephanos stephanos May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't use := as var blob *commonpb.DataBlob is declared before and in Go we can't mix existing/non-existing vars in :=. To make it work, we'd have to invent another temp var.

if blob, err = encodeChasmBlob(field.val.Interface().(proto.Message)); err != nil {
return err
}
blob = &commonpb.DataBlob{EncodingType: enumspb.ENCODING_TYPE_PROTO3, Data: data}
}

n.serializedNode.Data = blob
Expand Down Expand Up @@ -1205,11 +1196,10 @@ func (n *Node) serializeDataNode() error {

var blob *commonpb.DataBlob
if protoValue != nil {
data, err := marshalDeterministic(protoValue)
if err != nil {
return serialization.NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err)
var err error
if blob, err = encodeChasmBlob(protoValue); err != nil {
return err
}
blob = &commonpb.DataBlob{EncodingType: enumspb.ENCODING_TYPE_PROTO3, Data: data}
}
n.serializedNode.Data = blob
n.updateLastUpdateVersionedTransition()
Expand Down Expand Up @@ -3077,7 +3067,7 @@ func serializeTask(
) (*commonpb.DataBlob, error) {
protoValue, ok := taskValue.Interface().(proto.Message)
if ok {
return serialization.ProtoEncode(protoValue)
return encodeChasmBlob(protoValue)
}

taskGoType := registrableTask.goType
Expand All @@ -3090,10 +3080,7 @@ func serializeTask(

// Handle empty task struct.
if taskGoType.NumField() == 0 {
return &commonpb.DataBlob{
Data: nil,
EncodingType: enumspb.ENCODING_TYPE_PROTO3,
}, nil
return encodeChasmBlob(nil)
}

// TODO: consider pre-calculating the proto field num when registring the task type.
Expand All @@ -3112,7 +3099,7 @@ func serializeTask(
protoMessageFound = true

var err error
blob, err = serialization.ProtoEncode(fieldV.Interface().(proto.Message))
blob, err = encodeChasmBlob(fieldV.Interface().(proto.Message))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -3501,3 +3488,9 @@ func makeValidationFn(
return nil
}
}

// encodeChasmBlob encodes CHASM data and task payloads through the env-aware
// serializer while preserving deterministic proto3 bytes for byte comparisons.
func encodeChasmBlob(m proto.Message) (*commonpb.DataBlob, error) {
return serialization.Encode(m, serialization.WithDeterministicProto3)
}
48 changes: 22 additions & 26 deletions chasm/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"go.temporal.io/server/common/definition"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/persistence/serialization"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/testing/protoassert"
"go.temporal.io/server/common/testing/protorequire"
Expand Down Expand Up @@ -252,6 +251,7 @@ func (s *nodeSuite) TestSerializeNode_DataAttributes() {
err := node.serialize()
s.NoError(err)
s.NotNil(node.serializedNode.GetData(), "child node serialized value must have data after serialize is called")
s.Equal(enumspb.ENCODING_TYPE_PROTO3, node.serializedNode.GetData().GetEncodingType())
s.Equal([]byte{0xa, 0x2, 0x32, 0x32}, node.serializedNode.GetData().GetData())
s.Equal(valueStateSynced, node.valueState)
}
Expand Down Expand Up @@ -899,7 +899,7 @@ func (s *nodeSuite) TestNodeSnapshot() {

func (s *nodeSuite) TestApplyMutation() {
mustEncode := func(m proto.Message) *commonpb.DataBlob {
taskBlob, err := serialization.ProtoEncode(m)
taskBlob, err := encodeChasmBlob(m)
s.NoError(err)
return taskBlob
}
Expand Down Expand Up @@ -1199,9 +1199,7 @@ func (s *nodeSuite) TestApplySnapshot() {
// - a new node "SubComponent2" is added.

now := timestamppb.Now()
updatedRootData, err := serialization.ProtoEncode(&protoMessageType{
StartTime: now,
})
updatedRootData, err := encodeChasmBlob(&protoMessageType{StartTime: now})
s.NoError(err)
incomingSnapshot := NodesSnapshot{
Nodes: map[string]*persistencespb.ChasmNode{
Expand Down Expand Up @@ -2098,7 +2096,7 @@ func (s *nodeSuite) TestSerializeDeserializeTask() {
payload := &commonpb.Payload{
Data: []byte("some-random-data"),
}
expectedBlob, err := serialization.ProtoEncode(payload)
expectedBlob, err := encodeChasmBlob(payload)
s.NoError(err)

testCases := []struct {
Expand Down Expand Up @@ -2338,8 +2336,9 @@ func (s *nodeSuite) TestCloseTransaction_InvalidateComponentTasks() {
payload := &commonpb.Payload{
Data: []byte("some-random-data"),
}
taskBlob, err := serialization.ProtoEncode(payload)
taskBlob, err := encodeChasmBlob(payload)
s.NoError(err)
emptyTaskBlob := s.emptyDataBlob()

persistenceNodes := map[string]*persistencespb.ChasmNode{
"": {
Expand All @@ -2361,11 +2360,8 @@ func (s *nodeSuite) TestCloseTransaction_InvalidateComponentTasks() {
TypeId: testOutboundSideEffectTaskTypeID,
VersionedTransition: &persistencespb.VersionedTransition{TransitionCount: 1},
VersionedTransitionOffset: 2,
Data: &commonpb.DataBlob{
Data: nil,
EncodingType: enumspb.ENCODING_TYPE_PROTO3,
},
PhysicalTaskStatus: physicalTaskStatusCreated,
Data: emptyTaskBlob,
PhysicalTaskStatus: physicalTaskStatusCreated,
},
},
PureTasks: []*persistencespb.ChasmComponentAttributes_Task{
Expand Down Expand Up @@ -2457,7 +2453,7 @@ func (s *nodeSuite) TestCloseTransaction_PausedStateInvalidatesTasks() {
payload := &commonpb.Payload{
Data: []byte("some-random-data"),
}
taskBlob, err := serialization.ProtoEncode(payload)
taskBlob, err := encodeChasmBlob(payload)
s.NoError(err)

makeTask := func(typeID uint32, offset int64) *persistencespb.ChasmComponentAttributes_Task {
Expand Down Expand Up @@ -3186,7 +3182,7 @@ func (s *nodeSuite) TestEachPureTask() {
now := s.timeSource.Now()

mustEncode := func(m proto.Message) *commonpb.DataBlob {
taskBlob, err := serialization.ProtoEncode(m)
taskBlob, err := encodeChasmBlob(m)
s.NoError(err)
return taskBlob
}
Expand Down Expand Up @@ -3516,6 +3512,7 @@ func (s *nodeSuite) TestExecuteSideEffectTask() {
},
}

emptyTaskBlob := s.emptyDataBlob()
taskInfo := &persistencespb.ChasmTaskInfo{
ComponentInitialVersionedTransition: &persistencespb.VersionedTransition{
TransitionCount: 1,
Expand All @@ -3526,10 +3523,7 @@ func (s *nodeSuite) TestExecuteSideEffectTask() {
Path: []string{"SubComponent1"},
TypeId: testSideEffectTaskTypeID,
ArchetypeId: testComponentTypeID,
Data: &commonpb.DataBlob{
Data: nil,
EncodingType: enumspb.ENCODING_TYPE_PROTO3,
},
Data: emptyTaskBlob,
}
workflowKey := definition.NewWorkflowKey(
primitives.NewUUID().String(),
Expand Down Expand Up @@ -3666,6 +3660,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() {
primitives.NewUUID().String(),
primitives.NewUUID().String(),
)
emptyTaskBlob := s.emptyDataBlob()
chasmTask := &tasks.ChasmTask{
WorkflowKey: workflowKey,
VisibilityTimestamp: s.timeSource.Now(),
Expand All @@ -3682,10 +3677,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() {
Path: []string{"SubComponent1"},
TypeId: testDiscardableSideEffectTaskTypeID,
ArchetypeId: testComponentTypeID,
Data: &commonpb.DataBlob{
Data: nil,
EncodingType: enumspb.ENCODING_TYPE_PROTO3,
},
Data: emptyTaskBlob,
},
}
executionKey := ExecutionKey{
Expand Down Expand Up @@ -3806,6 +3798,7 @@ func (s *nodeSuite) TestExecuteSideEffectDiscardTask() {
}

func (s *nodeSuite) TestValidateSideEffectTask() {
emptyTaskBlob := s.emptyDataBlob()
taskInfo := &persistencespb.ChasmTaskInfo{
ComponentInitialVersionedTransition: &persistencespb.VersionedTransition{
TransitionCount: 1,
Expand All @@ -3817,10 +3810,7 @@ func (s *nodeSuite) TestValidateSideEffectTask() {
},
Path: rootPath,
TypeId: testSideEffectTaskTypeID,
Data: &commonpb.DataBlob{
Data: nil,
EncodingType: enumspb.ENCODING_TYPE_PROTO3,
},
Data: emptyTaskBlob,
}
workflowKey := definition.NewWorkflowKey(
primitives.NewUUID().String(),
Expand Down Expand Up @@ -3950,3 +3940,9 @@ func (s *nodeSuite) newTestTree(
}
return NewTreeFromDB(serializedNodes, s.registry, s.timeSource, s.nodeBackend, s.nodePathEncoder, s.logger, s.metricsHandler)
}

func (s *nodeSuite) emptyDataBlob() *commonpb.DataBlob {
blob, err := encodeChasmBlob(nil)
s.NoError(err)
return blob
}
40 changes: 33 additions & 7 deletions common/persistence/serialization/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (
// WARNING: This environment variable should only be used for testing; and never set it in production.
const SerializerDataEncodingEnvVar = "TEMPORAL_TEST_DATA_ENCODING"

// EncodingTypeFromEnv returns an EncodingType based on the environment variable `TEMPORAL_TEST_DATA_ENCODING`.
// encodingTypeFromEnv returns an EncodingType based on the environment variable `TEMPORAL_TEST_DATA_ENCODING`.
// It defaults to "ENCODING_TYPE_PROTO3" codec if the environment variable is not set.
func EncodingTypeFromEnv() enumspb.EncodingType {
func encodingTypeFromEnv() enumspb.EncodingType {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by change: this didn't need to be exported

codecType := os.Getenv(SerializerDataEncodingEnvVar)
switch strings.ToLower(codecType) {
case "", "proto3":
Expand All @@ -38,12 +38,38 @@ func EncodingTypeFromEnv() enumspb.EncodingType {
}
}

// ProtoEncode is kept for backward compatibility.
func ProtoEncode(m proto.Message) (*commonpb.DataBlob, error) {
return encodeBlob(m, enumspb.ENCODING_TYPE_PROTO3)
type (
encodeOptions struct {
deterministic bool
}
EncodeOption func(*encodeOptions)
)

// WithDeterministicProto3 uses deterministic marshaling when Encode selects proto3.
//
// Deterministic encoding sorts map keys before encoding, making byte comparison
// a reliable equality check for any well-formed proto message. For messages
// without map fields this is a no-op with no performance overhead.
var WithDeterministicProto3 EncodeOption = func(opts *encodeOptions) {
opts.deterministic = true
}

func encodeBlob(m proto.Message, encoding enumspb.EncodingType) (*commonpb.DataBlob, error) {
// Encode encodes the given proto message. It respects the `TEMPORAL_TEST_DATA_ENCODING` environment variable;
// otherwise, it defaults to "ENCODING_TYPE_PROTO3".
func Encode(m proto.Message, options ...EncodeOption) (*commonpb.DataBlob, error) {
return encodeBlob(m, encodingTypeFromEnv(), options...)
}

func encodeBlob(
m proto.Message,
encoding enumspb.EncodingType,
options ...EncodeOption,
) (*commonpb.DataBlob, error) {
opts := encodeOptions{}
for _, option := range options {
option(&opts)
}

if m == nil {
return &commonpb.DataBlob{
Data: nil,
Expand All @@ -62,7 +88,7 @@ func encodeBlob(m proto.Message, encoding enumspb.EncodingType) (*commonpb.DataB
EncodingType: enumspb.ENCODING_TYPE_JSON,
}, nil
case enumspb.ENCODING_TYPE_PROTO3:
data, err := proto.Marshal(m)
data, err := proto.MarshalOptions{Deterministic: opts.deterministic}.Marshal(m)
if err != nil {
return nil, NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, err)
}
Expand Down
Loading
Loading