Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,32 @@
"--config",
"proxy-config.yaml",
]
},
{
"name": "Debug Proxy (cluster-a mux client)",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}/cmd/proxy",
"cwd": "${workspaceFolder}",
"args": [
"start",
"--config",
"develop/config/cluster-a-mux-client-proxy.yaml",
]
},
{
"name": "Debug Proxy (cluster-b mux server)",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}/cmd/proxy",
"cwd": "${workspaceFolder}",
"args": [
"start",
"--config",
"develop/config/cluster-b-mux-server-proxy.yaml",
]
}
]
}
5 changes: 5 additions & 0 deletions develop/config/cluster-a-mux-client-proxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ clusterConnections:
connectionType: "mux-client"
muxAddressInfo:
address: "localhost:6334"
namespaceTranslation:
mappings:
- local: "myNamespace"
remote: "myNamespace.accountid"

45 changes: 22 additions & 23 deletions interceptor/translation_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,33 @@ func (i *TranslationInterceptor) Intercept(
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
if len(i.translators) > 0 &&
strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) ||
strings.HasPrefix(info.FullMethod, api.AdminServicePrefix) {

methodName := api.MethodName(info.FullMethod)

for _, tr := range i.translators {
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateRequest(req)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start))
}
}
if common.IsRequestTranslationDisabled(ctx) || len(i.translators) == 0 ||
(!strings.HasPrefix(info.FullMethod, api.WorkflowServicePrefix) &&
!strings.HasPrefix(info.FullMethod, api.AdminServicePrefix)) {
return handler(ctx, req)
}

resp, err := handler(ctx, req)
methodName := api.MethodName(info.FullMethod)

for _, tr := range i.translators {
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateResponse(resp)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start))
}
for _, tr := range i.translators {
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateRequest(req)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Request", req, time.Since(start))
}
}

return resp, err
} else {
return handler(ctx, req)
resp, err := handler(ctx, req)

for _, tr := range i.translators {
if tr.MatchMethod(info.FullMethod) {
start := time.Now()
changed, trErr := tr.TranslateResponse(resp)
logTranslateResult(tr, i.logger, changed, trErr, methodName+"Response", resp, time.Since(start))
}
}

return resp, err
}

func (i *TranslationInterceptor) InterceptStream(
Expand Down
84 changes: 84 additions & 0 deletions interceptor/translation_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package interceptor

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/common/api"
"go.temporal.io/server/common/log"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/temporalio/s2s-proxy/common"
)

// spyTranslator records how many times each Translator method is invoked
// so tests can verify whether translation logic ran.
type spyTranslator struct {
matchCalls int
translateReqCalls int
translateRespCalls int
}

func (s *spyTranslator) Kind() string { return "spy" }
func (s *spyTranslator) MatchMethod(string) bool { s.matchCalls++; return true }
func (s *spyTranslator) TranslateRequest(any) (bool, error) { s.translateReqCalls++; return false, nil }
func (s *spyTranslator) TranslateResponse(any) (bool, error) {
s.translateRespCalls++
return false, nil
}

func TestTranslationInterceptor(t *testing.T) {
logger := log.NewTestLogger()
info := &grpc.UnaryServerInfo{
FullMethod: api.WorkflowServicePrefix + "DescribeWorkflowExecution",
}
handler := func(_ context.Context, _ any) (any, error) {
return &workflowservice.DescribeWorkflowExecutionResponse{}, nil
}

cases := []struct {
name string
// incomingHeaders is attached to the request context (nil for none).
incomingHeaders map[string]string
// MatchMethod is consulted once per phase (request + response) when translation runs.
expectedMatchCalls int
expectedReqCalls int
expectedRespCalls int
}{
{
name: "header_false_skips_translators",
incomingHeaders: map[string]string{common.RequestTranslationHeaderName: "false"},
expectedMatchCalls: 0,
expectedReqCalls: 0,
expectedRespCalls: 0,
},
{
name: "header_absent_invokes_translators",
incomingHeaders: nil,
expectedMatchCalls: 2,
expectedReqCalls: 1,
expectedRespCalls: 1,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
spy := &spyTranslator{}
ti := NewTranslationInterceptor(logger, []Translator{spy})

ctx := context.Background()
if tc.incomingHeaders != nil {
ctx = metadata.NewIncomingContext(ctx, metadata.New(tc.incomingHeaders))
}
_, err := ti.Intercept(ctx, &workflowservice.DescribeWorkflowExecutionRequest{}, info, handler)
require.NoError(t, err)

require.Equal(t, tc.expectedMatchCalls, spy.matchCalls, "MatchMethod call count")
require.Equal(t, tc.expectedReqCalls, spy.translateReqCalls, "TranslateRequest call count")
require.Equal(t, tc.expectedRespCalls, spy.translateRespCalls, "TranslateResponse call count")
})
}
}
15 changes: 12 additions & 3 deletions proxy/workflowservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import (

const DCRedirectionContextHeaderName = "xdc-redirection" // https://github.com/temporalio/temporal/blob/9a1060c4162ff62576cb899d7e5b1bae179af814/common/rpc/interceptor/redirection.go#L27

// PreservedHeaders are forwarded from the incoming context to the outgoing
// context so they survive the proxy hop.
var PreservedHeaders = []string{
DCRedirectionContextHeaderName,
common.RequestTranslationHeaderName,
}

type (
workflowServiceProxyServer struct {
workflowservice.UnimplementedWorkflowServiceServer
Expand Down Expand Up @@ -310,9 +317,11 @@ func (s *workflowServiceProxyServer) UpdateWorkflowExecution(ctx context.Context
}

func copyContext(src context.Context) context.Context {
val := metadata.ValueFromIncomingContext(src, DCRedirectionContextHeaderName)
if len(val) > 0 {
src = metadata.AppendToOutgoingContext(src, DCRedirectionContextHeaderName, val[0])
for _, header := range PreservedHeaders {
val := metadata.ValueFromIncomingContext(src, header)
if len(val) > 0 {
src = metadata.AppendToOutgoingContext(src, header, val[0])
}
}
return src
}
15 changes: 10 additions & 5 deletions proxy/workflowservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,28 @@ func (s *workflowServiceTestSuite) TestNamespaceFiltering() {
s.Equalf(-1, steveIndex, "Shouldn't have found namespace %s\n in list response: %s", matchingNS, res.Namespaces)
}

func (s *workflowServiceTestSuite) TestPreserveRedirectionHeader() {
func (s *workflowServiceTestSuite) TestPreservedHeaders() {
loggerProvider := logging.NewLoggerProvider(log.NewTestLogger(), config.NewMockConfigProvider(config.S2SProxyConfig{}))
wfProxy := NewWorkflowServiceProxyServer("My cool test server", s.clientMock, nil, loggerProvider)

// Client should be called with xdc-redirection=false header
// All PreservedHeaders should be forwarded from the incoming context.
for _, headerValue := range []string{"true", "false", ""} {
s.clientMock.EXPECT().DescribeWorkflowExecution(gomockold.Any(), gomockold.Any()).DoAndReturn(
func(ctx context.Context, request *workflowservice.DescribeWorkflowExecutionRequest, opts ...grpc.CallOption) (*workflowservice.DescribeWorkflowExecutionResponse, error) {
md, ok := metadata.FromOutgoingContext(ctx)
s.True(ok)
s.Equal(md.Get(DCRedirectionContextHeaderName), []string{headerValue})
for _, header := range PreservedHeaders {
s.Equal([]string{headerValue}, md.Get(header), "header %q not forwarded", header)
}
return &workflowservice.DescribeWorkflowExecutionResponse{}, nil
},
).Times(1)

// API is passed xdc-redirection=false header
ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{DCRedirectionContextHeaderName: headerValue}))
incoming := map[string]string{}
for _, header := range PreservedHeaders {
incoming[header] = headerValue
}
ctx := metadata.NewIncomingContext(context.Background(), metadata.New(incoming))
res, err := wfProxy.DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{})
s.NoError(err)
s.NotNil(res)
Expand Down
Loading