diff --git a/.vscode/launch.json b/.vscode/launch.json index 506c44f..a2a81d6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,6 +16,32 @@ "--config", "proxy-config.yaml", ] + }, + { + "name": "Debug Proxy (cluster-a mux client)", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}/cmd/proxy", + "cwd": "${workspaceFolder}", + "args": [ + "start", + "--config", + "develop/config/cluster-a-mux-client-proxy.yaml", + ] + }, + { + "name": "Debug Proxy (cluster-b mux server)", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}/cmd/proxy", + "cwd": "${workspaceFolder}", + "args": [ + "start", + "--config", + "develop/config/cluster-b-mux-server-proxy.yaml", + ] } ] } \ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy.yaml b/develop/config/cluster-a-mux-client-proxy.yaml index 0446d43..8ba5925 100644 --- a/develop/config/cluster-a-mux-client-proxy.yaml +++ b/develop/config/cluster-a-mux-client-proxy.yaml @@ -10,3 +10,8 @@ clusterConnections: connectionType: "mux-client" muxAddressInfo: address: "localhost:6334" + namespaceTranslation: + mappings: + - local: "myNamespace" + remote: "myNamespace.accountid" + diff --git a/interceptor/translation_interceptor.go b/interceptor/translation_interceptor.go index 3fc2b83..f06b8c4 100644 --- a/interceptor/translation_interceptor.go +++ b/interceptor/translation_interceptor.go @@ -40,34 +40,33 @@ func (i *TranslationInterceptor) Intercept( info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (any, error) { - if len(i.translators) > 0 && - strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) || - strings.HasPrefix(info.FullMethod, api.AdminServicePrefix) { - - methodName := api.MethodName(info.FullMethod) - - for _, tr := range i.translators { - if tr.MatchMethod(info.FullMethod) { - start := time.Now() - changed, trErr := tr.TranslateRequest(req) - logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start)) - } - } + if common.IsRequestTranslationDisabled(ctx) || len(i.translators) == 0 || + (!strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) && + !strings.HasPrefix(info.FullMethod, api.AdminServicePrefix)) { + return handler(ctx, req) + } - resp, err := handler(ctx, req) + methodName := api.MethodName(info.FullMethod) - for _, tr := range i.translators { - if tr.MatchMethod(info.FullMethod) { - start := time.Now() - changed, trErr := tr.TranslateResponse(resp) - logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start)) - } + for _, tr := range i.translators { + if tr.MatchMethod(info.FullMethod) { + start := time.Now() + changed, trErr := tr.TranslateRequest(req) + logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start)) } + } - return resp, err - } else { - return handler(ctx, req) + resp, err := handler(ctx, req) + + for _, tr := range i.translators { + if tr.MatchMethod(info.FullMethod) { + start := time.Now() + changed, trErr := tr.TranslateResponse(resp) + logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start)) + } } + + return resp, err } func (i *TranslationInterceptor) InterceptStream( diff --git a/interceptor/translation_interceptor_test.go b/interceptor/translation_interceptor_test.go new file mode 100644 index 0000000..e67b478 --- /dev/null +++ b/interceptor/translation_interceptor_test.go @@ -0,0 +1,84 @@ +package interceptor + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/server/common/api" + "go.temporal.io/server/common/log" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/temporalio/s2s-proxy/common" +) + +// spyTranslator records how many times each Translator method is invoked +// so tests can verify whether translation logic ran. +type spyTranslator struct { + matchCalls int + translateReqCalls int + translateRespCalls int +} + +func (s *spyTranslator) Kind() string { return "spy" } +func (s *spyTranslator) MatchMethod(string) bool { s.matchCalls++; return true } +func (s *spyTranslator) TranslateRequest(any) (bool, error) { s.translateReqCalls++; return false, nil } +func (s *spyTranslator) TranslateResponse(any) (bool, error) { + s.translateRespCalls++ + return false, nil +} + +func TestTranslationInterceptor(t *testing.T) { + logger := log.NewTestLogger() + info := &grpc.UnaryServerInfo{ + FullMethod: api.WorkflowServicePrefix + "DescribeWorkflowExecution", + } + handler := func(_ context.Context, _ any) (any, error) { + return &workflowservice.DescribeWorkflowExecutionResponse{}, nil + } + + cases := []struct { + name string + // incomingHeaders is attached to the request context (nil for none). + incomingHeaders map[string]string + // MatchMethod is consulted once per phase (request + response) when translation runs. + expectedMatchCalls int + expectedReqCalls int + expectedRespCalls int + }{ + { + name: "header_false_skips_translators", + incomingHeaders: map[string]string{common.RequestTranslationHeaderName: "false"}, + expectedMatchCalls: 0, + expectedReqCalls: 0, + expectedRespCalls: 0, + }, + { + name: "header_absent_invokes_translators", + incomingHeaders: nil, + expectedMatchCalls: 2, + expectedReqCalls: 1, + expectedRespCalls: 1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + spy := &spyTranslator{} + ti := NewTranslationInterceptor(logger, []Translator{spy}) + + ctx := context.Background() + if tc.incomingHeaders != nil { + ctx = metadata.NewIncomingContext(ctx, metadata.New(tc.incomingHeaders)) + } + _, err := ti.Intercept(ctx, &workflowservice.DescribeWorkflowExecutionRequest{}, info, handler) + require.NoError(t, err) + + require.Equal(t, tc.expectedMatchCalls, spy.matchCalls, "MatchMethod call count") + require.Equal(t, tc.expectedReqCalls, spy.translateReqCalls, "TranslateRequest call count") + require.Equal(t, tc.expectedRespCalls, spy.translateRespCalls, "TranslateResponse call count") + }) + } +} diff --git a/proxy/workflowservice.go b/proxy/workflowservice.go index da66c41..b6af3fb 100644 --- a/proxy/workflowservice.go +++ b/proxy/workflowservice.go @@ -14,6 +14,13 @@ import ( const DCRedirectionContextHeaderName = "xdc-redirection" // https://github.com/temporalio/temporal/blob/9a1060c4162ff62576cb899d7e5b1bae179af814/common/rpc/interceptor/redirection.go#L27 +// PreservedHeaders are forwarded from the incoming context to the outgoing +// context so they survive the proxy hop. +var PreservedHeaders = []string{ + DCRedirectionContextHeaderName, + common.RequestTranslationHeaderName, +} + type ( workflowServiceProxyServer struct { workflowservice.UnimplementedWorkflowServiceServer @@ -310,9 +317,11 @@ func (s *workflowServiceProxyServer) UpdateWorkflowExecution(ctx context.Context } func copyContext(src context.Context) context.Context { - val := metadata.ValueFromIncomingContext(src, DCRedirectionContextHeaderName) - if len(val) > 0 { - src = metadata.AppendToOutgoingContext(src, DCRedirectionContextHeaderName, val[0]) + for _, header := range PreservedHeaders { + val := metadata.ValueFromIncomingContext(src, header) + if len(val) > 0 { + src = metadata.AppendToOutgoingContext(src, header, val[0]) + } } return src } diff --git a/proxy/workflowservice_test.go b/proxy/workflowservice_test.go index 9858ec2..a8e7ac1 100644 --- a/proxy/workflowservice_test.go +++ b/proxy/workflowservice_test.go @@ -107,23 +107,28 @@ func (s *workflowServiceTestSuite) TestNamespaceFiltering() { s.Equalf(-1, steveIndex, "Shouldn't have found namespace %s\n in list response: %s", matchingNS, res.Namespaces) } -func (s *workflowServiceTestSuite) TestPreserveRedirectionHeader() { +func (s *workflowServiceTestSuite) TestPreservedHeaders() { loggerProvider := logging.NewLoggerProvider(log.NewTestLogger(), config.NewMockConfigProvider(config.S2SProxyConfig{})) wfProxy := NewWorkflowServiceProxyServer("My cool test server", s.clientMock, nil, loggerProvider) - // Client should be called with xdc-redirection=false header + // All PreservedHeaders should be forwarded from the incoming context. for _, headerValue := range []string{"true", "false", ""} { s.clientMock.EXPECT().DescribeWorkflowExecution(gomockold.Any(), gomockold.Any()).DoAndReturn( func(ctx context.Context, request *workflowservice.DescribeWorkflowExecutionRequest, opts ...grpc.CallOption) (*workflowservice.DescribeWorkflowExecutionResponse, error) { md, ok := metadata.FromOutgoingContext(ctx) s.True(ok) - s.Equal(md.Get(DCRedirectionContextHeaderName), []string{headerValue}) + for _, header := range PreservedHeaders { + s.Equal([]string{headerValue}, md.Get(header), "header %q not forwarded", header) + } return &workflowservice.DescribeWorkflowExecutionResponse{}, nil }, ).Times(1) - // API is passed xdc-redirection=false header - ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{DCRedirectionContextHeaderName: headerValue})) + incoming := map[string]string{} + for _, header := range PreservedHeaders { + incoming[header] = headerValue + } + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(incoming)) res, err := wfProxy.DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{}) s.NoError(err) s.NotNil(res)