diff --git a/README.md b/README.md index e93d14b..7625339 100644 --- a/README.md +++ b/README.md @@ -1,35 +1,196 @@ -# Temporal Proxy with AWS KMS Encryption POC +# Temporal Cloud Proxy -This poc project implements a proxy for Temporal workflows with end-to-end encryption using AWS KMS. +A Temporal Cloud proxy that handles Temporal namespace authentication, payload encryption/decryption, and client/worker authentication for multiple workloads with different configurations. -## Prerequisites +## Key Features -- Go 1.16 or later -- AWS account with permissions to create and use KMS keys -- A Temporal Cloud account +- **Multi-workload Support** - Handle multiple Temporal workload configurations through a single proxy instance +- **Payload Encryption/Decryption** - AWS KMS and GCP KMS support with intelligent caching for performance +- **Temporal Cloud Namespace Authentication** - Support for mTLS and API keys +- **Worker Authentication** - Support for JWT (with JWKS) and SPIFFE/SPIRE +- **Observability** - Built-in Prometheus metrics, Grafana dashboards, and structured logging -## Setup Instructions +## Quick Start -### 1. Create an AWS KMS Key +### Prerequisites -1. Sign in to the AWS Management Console and open the KMS console at https://console.aws.amazon.com/kms -2. Choose **Create key** -3. Select **Symmetric** for Key type -4. For Key usage, select **Encrypt and decrypt** -5. Add a name and description for your key -6. Configure key administrative permissions and key usage permissions -7. Review and finish creating the key -8. Note the ARN of your new key, which will look like: `arn:aws:kms:region:account-id:key/key-id` +- Go 1.24 or later +- AWS account with KMS permissions (for AWS KMS encryption) +- GCP account with KMS permissions (for GCP KMS encryption) +- Temporal Cloud account -### 2. Configure AWS Credentials +### Installation -Ensure your AWS credentials are properly configured: +1. Clone the repository: ```bash -aws configure +git clone +cd temporal-cloud-proxy ``` -Or set environment variables: +2. Build the binary: + +```bash +make build +``` + +3. Copy and configure the sample config: + +```bash +cp config.yaml.sample config.yaml +# Edit config.yaml with your settings +``` + +4. Run the proxy: + +```bash +./tclp --config config.yaml +``` + +## Configuration Reference + +The proxy is configured via a YAML file. Here's the basic structure: + +```yaml +server: + port: 7233 # Proxy server port + host: "0.0.0.0" # Bind address + +metrics: + port: 9090 # Prometheus metrics port + +encryption: + caching: + max_cache: 100 # Maximum cached encryption keys + max_age: "10m" # Key cache TTL + max_usage: 100 # Maximum key usage count + +workloads: + - workload_id: "my-workload" + temporal_cloud: + namespace: "my-namespace.my-account" + host_port: "my-namespace.my-account.tmprl.cloud:7233" + authentication: + # Choose one authentication method + tls: + cert_file: "/path/to/tls.crt" + key_file: "/path/to/tls.key" + # OR + api_key: + value: "your-api-key" + # OR env: "API_KEY_ENV_VAR" + encryption: + type: "aws-kms" + config: + key-id: "arn:aws:kms:region:account:key/key-id" + authentication: + type: "jwt" + config: + jwks-url: "https://your-auth-provider/.well-known/keys" + audiences: ["temporal_cloud_proxy"] +``` + +### Multiple Workloads Example + +```yaml +workloads: + - workload_id: "production" + temporal_cloud: + namespace: "prod.company" + host_port: "prod.company.tmprl.cloud:7233" + authentication: + api_key: + env: "PROD_API_KEY" + encryption: + type: "aws-kms" + config: + key-id: "arn:aws:kms:us-east-1:123456789:key/prod-key" + authentication: + type: "spiffe" + config: + trust_domain: "spiffe://company.com/" + endpoint: "unix:///tmp/spire-agent/public/api.sock" + + - workload_id: "staging" + temporal_cloud: + namespace: "staging.company" + host_port: "staging.company.tmprl.cloud:7233" + authentication: + tls: + cert_file: "/certs/staging.crt" + key_file: "/certs/staging.key" + encryption: + type: "gcp-kms" + config: + key-name: "projects/my-project/locations/us-central1/keyRings/staging/cryptoKeys/temporal" +``` + +## Temporal Cloud Namespace Authentication Methods + +### mTLS Authentication + +```yaml +temporal_cloud: + authentication: + tls: + cert_file: "/path/to/client.crt" + key_file: "/path/to/client.key" +``` + +### API Key Authentication + +```yaml +temporal_cloud: + authentication: + api_key: + value: "your-api-key" + # OR use environment variable + env: "TEMPORAL_API_KEY" +``` + +## Worker to Proxy Authentication Methods + +### JWT Authentication + +```yaml +authentication: + type: "jwt" + config: + jwks-url: "https://auth.company.com/.well-known/keys" + audiences: ["temporal_cloud_proxy"] +``` + +### SPIFFE Authentication + +```yaml +authentication: + type: "spiffe" + config: + trust_domain: "spiffe://company.com/" + endpoint: "unix:///tmp/spire-agent/public/api.sock" + audiences: ["temporal_cloud_proxy"] +``` + +## Encryption & Security + +### AWS KMS Configuration + +1. Create a KMS key in AWS: + +```bash +aws kms create-key --description "Temporal Cloud Proxy Encryption Key" +``` + +2. Configure the proxy: + +```yaml +encryption: + type: "aws-kms" + config: + key-id: "arn:aws:kms:region:account-id:key/key-id" +``` + +3. Ensure AWS credentials are configured: ```bash export AWS_ACCESS_KEY_ID=your_access_key @@ -37,16 +198,119 @@ export AWS_SECRET_ACCESS_KEY=your_secret_key export AWS_REGION=your_region ``` -## Build +### GCP KMS Configuration + +1. Create a KMS key in GCP: -```sh -make +```bash +gcloud kms keyrings create temporal-proxy --location=global +gcloud kms keys create encryption-key --location=global --keyring=temporal-proxy --purpose=encryption ``` -## Run +2. Configure the proxy: -Update `config.yaml` with namespace details. +```yaml +encryption: + type: "gcp-kms" + config: + key-name: "projects/PROJECT_ID/locations/global/keyRings/temporal-proxy/cryptoKeys/encryption-key" +``` -```sh -./tclp --config config.yaml +### Encryption Caching + +The proxy includes intelligent caching to optimize encryption performance: + +```yaml +encryption: + caching: + max_cache: 100 # Maximum number of cached keys + max_age: "10m" # Maximum age of cached keys + max_usage: 100 # Maximum usage count per key +``` + +## Monitoring & Observability + +### Prometheus Metrics + +The proxy exposes metrics on the configured metrics port (default: 9090): + +- `proxy_request_total` - Total number of proxy requests +- `proxy_request_errors` - Number of failed requests +- `proxy_request_success` - Number of successful requests +- `proxy_latency` - Request latency histogram +- Encryption/decryption metrics per workload +- Authentication success/failure rates + +### Grafana Dashboard + +A pre-configured Grafana dashboard is available at `dashboards/grafana-dashboard.json`. Import this dashboard to visualize: + +- Request throughput and error rates +- Authentication success rates +- Encryption performance metrics +- Per-workload statistics + +### Logging + +Configure log levels using the `--log-level` flag: + +```bash +./tclp --config config.yaml --log-level debug +``` + +Available levels: `debug`, `info`, `warn`, `error` + +## Development & Testing + +### Building + +```bash +# Build the binary +make build + +# Clean build artifacts +make clean + +# Build and test +make all +``` + +### Testing + +```bash +# Run all tests +make test + +# Run tests with verbose output +make test-verbose + +# Run tests with coverage +make test-coverage + +# Run race condition tests +make test-race + +# Run benchmarks +make benchmark +``` + +## Temporal Worker Configuration + +### Required Headers + +- `workload-id`: Identifies which workload configuration to use +- `authorization`: Authentication token (when worker authentication is enabled) + +### Example Implementations + +- [Temporal Worker with SPIFFE authentication](https://github.com/temporal-sa/temporal-proxy-spiffe-worker) +- [Temporal Worker with JWT authentication](https://github.com/temporal-sa/temporal-proxy-jwt-worker) + + +## Debug Logging + +Enable debug logging for detailed troubleshooting: + +```bash +./tclp --config config.yaml --log-level debug ``` diff --git a/auth/authenticator.go b/auth/authenticator.go index 20205a9..7ac613c 100644 --- a/auth/authenticator.go +++ b/auth/authenticator.go @@ -2,19 +2,63 @@ package auth import ( "context" + "fmt" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "go.uber.org/zap" "time" ) -type AuthenticationResult struct { - Authenticated bool - Subject string - Claims map[string]interface{} - Expiration time.Time +type ( + AuthenticatorFactory interface { + NewAuthenticator(authConfig config.AuthConfig) (Authenticator, error) + } + + Authenticator interface { + Type() string + Init(ctx context.Context, config map[string]interface{}) error + Authenticate(ctx context.Context, credentials interface{}) (*AuthenticationResult, error) + Close() error + } + + AuthenticationResult struct { + Authenticated bool + Subject string + Claims map[string]interface{} + Expiration time.Time + } + + AuthenticatorConstructor func(config map[string]interface{}) (Authenticator, error) + + authenticatorFactory struct { + providers map[string]AuthenticatorConstructor + } +) + +func newAuthenticatorFactoryProvider(ctx context.Context, _ *zap.Logger) (AuthenticatorFactory, error) { + af := &authenticatorFactory{ + providers: make(map[string]AuthenticatorConstructor), + } + + af.providers["spiffe"] = func(config map[string]interface{}) (Authenticator, error) { + authenticator := &SpiffeAuthenticator{} + err := authenticator.Init(ctx, config) + return authenticator, err + } + + af.providers["jwt"] = func(config map[string]interface{}) (Authenticator, error) { + authenticator := &JwtAuthenticator{} + err := authenticator.Init(ctx, config) + return authenticator, err + } + + return af, nil } -type Authenticator interface { - Type() string - Init(ctx context.Context, config map[string]interface{}) error - Authenticate(ctx context.Context, credentials interface{}) (*AuthenticationResult, error) - Close() error +func (a *authenticatorFactory) NewAuthenticator(authConfig config.AuthConfig) (Authenticator, error) { + authenticator, ok := a.providers[authConfig.Type] + if !ok { + return nil, fmt.Errorf("authenticator not found for type %s", authConfig.Type) + } + + return authenticator(authConfig.Config) } diff --git a/auth/authenticator_test.go b/auth/authenticator_test.go index df54ecc..51abf16 100644 --- a/auth/authenticator_test.go +++ b/auth/authenticator_test.go @@ -1,414 +1,87 @@ package auth import ( + "context" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/temporal-sa/temporal-cloud-proxy/config" ) -func TestAuthenticationResult(t *testing.T) { - tests := []struct { - name string - result *AuthenticationResult - }{ - { - name: "complete authentication result", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Claims: map[string]interface{}{ - "role": "admin", - "scope": "read:write", - "exp": 1234567890, - }, - Expiration: time.Now().Add(time.Hour), - }, - }, - { - name: "failed authentication result", - result: &AuthenticationResult{ - Authenticated: false, - Subject: "", - Claims: nil, - Expiration: time.Time{}, - }, - }, - { - name: "authentication result with empty claims", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "service-account", - Claims: map[string]interface{}{}, - Expiration: time.Now().Add(30 * time.Minute), - }, - }, - } +func TestNewAuthenticatorFactoryProvider(t *testing.T) { + ctx := context.Background() + logger := zap.NewNop() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test that all fields are accessible and have expected values - assert.Equal(t, tt.result.Authenticated, tt.result.Authenticated) - assert.Equal(t, tt.result.Subject, tt.result.Subject) - assert.Equal(t, tt.result.Claims, tt.result.Claims) - assert.Equal(t, tt.result.Expiration, tt.result.Expiration) + factory, err := newAuthenticatorFactoryProvider(ctx, logger) - // Test claims access if present - if tt.result.Claims != nil { - for key, expectedValue := range tt.result.Claims { - actualValue, exists := tt.result.Claims[key] - assert.True(t, exists, "Expected claim %s to exist", key) - assert.Equal(t, expectedValue, actualValue, "Expected claim %s to have correct value", key) - } - } - }) - } + assert.NoError(t, err) + assert.NotNil(t, factory) } -func TestAuthenticationResult_IsExpired(t *testing.T) { - now := time.Now() +func TestAuthenticatorFactory_NewAuthenticator_UnsupportedType(t *testing.T) { + ctx := context.Background() + logger := zap.NewNop() - tests := []struct { - name string - result *AuthenticationResult - checkTime time.Time - isExpired bool - }{ - { - name: "not expired - future expiration", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Expiration: now.Add(time.Hour), - }, - checkTime: now, - isExpired: false, - }, - { - name: "expired - past expiration", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Expiration: now.Add(-time.Hour), - }, - checkTime: now, - isExpired: true, - }, - { - name: "exactly at expiration time", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Expiration: now, - }, - checkTime: now, - isExpired: false, // Should not be expired at exact time - }, - { - name: "zero expiration time", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Expiration: time.Time{}, - }, - checkTime: now, - isExpired: true, // Zero time should be considered expired - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test expiration logic - isExpired := tt.result.Expiration.Before(tt.checkTime) && !tt.result.Expiration.IsZero() - if tt.result.Expiration.IsZero() { - isExpired = true // Zero time is always expired - } + factory, err := newAuthenticatorFactoryProvider(ctx, logger) + require.NoError(t, err) - assert.Equal(t, tt.isExpired, isExpired) - }) - } -} - -// TestAuthenticatorInterface verifies that our implementations satisfy the interface -func TestAuthenticatorInterface(t *testing.T) { - tests := []struct { - name string - auth Authenticator - }{ - { - name: "SpiffeAuthenticator implements Authenticator", - auth: &SpiffeAuthenticator{}, - }, - { - name: "JwtAuthenticator implements Authenticator", - auth: &JwtAuthenticator{}, + authConfig := config.AuthConfig{ + Type: "unsupported-type", + Config: map[string]interface{}{ + "some-config": "value", }, - { - name: "MockAuthenticator implements Authenticator", - auth: NewMockAuthenticator("test"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Verify interface compliance by checking Type method - authType := tt.auth.Type() - assert.NotEmpty(t, authType) - - // Verify that the authenticator implements all interface methods - // by checking that they can be assigned to the interface - var _ Authenticator = tt.auth - - // Test that the methods exist (compile-time check) - // We don't call them to avoid mock setup issues - assert.NotNil(t, tt.auth.Init) - assert.NotNil(t, tt.auth.Authenticate) - assert.NotNil(t, tt.auth.Close) - }) - } -} - -// TestAuthenticatorTypeUniqueness ensures different authenticator types return unique type strings -func TestAuthenticatorTypeUniqueness(t *testing.T) { - authenticators := []Authenticator{ - &SpiffeAuthenticator{}, - NewMockAuthenticator("mock1"), - NewMockAuthenticator("mock2"), } - types := make(map[string]bool) + authenticator, err := factory.NewAuthenticator(authConfig) - for _, auth := range authenticators { - authType := auth.Type() - assert.NotEmpty(t, authType, "Authenticator type should not be empty") - - // For mock authenticators with different types, they should be unique - if authType != "mock1" && authType != "mock2" { - assert.False(t, types[authType], "Authenticator type %s should be unique", authType) - } - types[authType] = true - } + assert.Error(t, err) + assert.Contains(t, err.Error(), "authenticator not found for type unsupported-type") + assert.Nil(t, authenticator) } -// TestAuthenticationResultClaimsManipulation tests working with claims -func TestAuthenticationResultClaimsManipulation(t *testing.T) { +func TestAuthenticationResult(t *testing.T) { + // Test the AuthenticationResult struct result := &AuthenticationResult{ Authenticated: true, - Subject: "test-user", - Claims: make(map[string]interface{}), - Expiration: time.Now().Add(time.Hour), - } - - // Test adding claims - result.Claims["role"] = "admin" - result.Claims["permissions"] = []string{"read", "write"} - result.Claims["numeric_claim"] = 42 - - // Verify claims were added - assert.Equal(t, "admin", result.Claims["role"]) - assert.Equal(t, []string{"read", "write"}, result.Claims["permissions"]) - assert.Equal(t, 42, result.Claims["numeric_claim"]) - - // Test modifying claims - result.Claims["role"] = "user" - assert.Equal(t, "user", result.Claims["role"]) - - // Test deleting claims - delete(result.Claims, "numeric_claim") - _, exists := result.Claims["numeric_claim"] - assert.False(t, exists) - - // Test claims count - assert.Equal(t, 2, len(result.Claims)) -} - -// TestAuthenticationResultCopy tests copying authentication results -func TestAuthenticationResultCopy(t *testing.T) { - original := &AuthenticationResult{ - Authenticated: true, - Subject: "original-user", + Subject: "test-subject", Claims: map[string]interface{}{ - "role": "admin", - "exp": 1234567890, + "iss": "test-issuer", + "sub": "test-subject", + "aud": "test-audience", }, - Expiration: time.Now().Add(time.Hour), } - // Create a copy - copy := &AuthenticationResult{ - Authenticated: original.Authenticated, - Subject: original.Subject, - Claims: make(map[string]interface{}), - Expiration: original.Expiration, - } - - // Copy claims - for k, v := range original.Claims { - copy.Claims[k] = v - } - - // Verify copy is identical - assert.Equal(t, original.Authenticated, copy.Authenticated) - assert.Equal(t, original.Subject, copy.Subject) - assert.Equal(t, original.Expiration, copy.Expiration) - assert.Equal(t, len(original.Claims), len(copy.Claims)) - - for k, v := range original.Claims { - assert.Equal(t, v, copy.Claims[k]) - } - - // Verify they are independent (modifying copy doesn't affect original) - copy.Subject = "modified-user" - copy.Claims["new_claim"] = "new_value" - - assert.NotEqual(t, original.Subject, copy.Subject) - _, exists := original.Claims["new_claim"] - assert.False(t, exists) + assert.True(t, result.Authenticated) + assert.Equal(t, "test-subject", result.Subject) + assert.Equal(t, "test-issuer", result.Claims["iss"]) + assert.Equal(t, "test-subject", result.Claims["sub"]) + assert.Equal(t, "test-audience", result.Claims["aud"]) } -// TestAuthenticationResultValidation tests validation scenarios -func TestAuthenticationResultValidation(t *testing.T) { - tests := []struct { - name string - result *AuthenticationResult - isValid bool - reason string - }{ - { - name: "valid authenticated result", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Claims: map[string]interface{}{"role": "admin"}, - Expiration: time.Now().Add(time.Hour), - }, - isValid: true, - reason: "complete valid result", - }, - { - name: "valid unauthenticated result", - result: &AuthenticationResult{ - Authenticated: false, - Subject: "", - Claims: nil, - Expiration: time.Time{}, - }, - isValid: true, - reason: "valid failure result", - }, - { - name: "authenticated but no subject", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "", - Claims: map[string]interface{}{"role": "admin"}, - Expiration: time.Now().Add(time.Hour), - }, - isValid: false, - reason: "authenticated results should have a subject", - }, - { - name: "authenticated but expired", - result: &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Claims: map[string]interface{}{"role": "admin"}, - Expiration: time.Now().Add(-time.Hour), - }, - isValid: false, - reason: "authenticated results should not be expired", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Basic validation logic - isValid := true - - if tt.result.Authenticated { - // If authenticated, should have a subject - if tt.result.Subject == "" { - isValid = false - } - // If authenticated, should not be expired - if !tt.result.Expiration.IsZero() && tt.result.Expiration.Before(time.Now()) { - isValid = false - } - } - - assert.Equal(t, tt.isValid, isValid, tt.reason) - }) - } -} - -// TestAuthenticationResultEdgeCases tests edge cases and boundary conditions -func TestAuthenticationResultEdgeCases(t *testing.T) { - t.Run("nil claims map", func(t *testing.T) { - result := &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Claims: nil, - Expiration: time.Now().Add(time.Hour), - } - - // Should not panic when accessing nil claims - assert.Nil(t, result.Claims) - - // Initialize claims if needed - if result.Claims == nil { - result.Claims = make(map[string]interface{}) - } - - result.Claims["test"] = "value" - assert.Equal(t, "value", result.Claims["test"]) - }) - - t.Run("very long subject", func(t *testing.T) { - longSubject := string(make([]byte, 10000)) - for i := range longSubject { - longSubject = longSubject[:i] + "a" + longSubject[i+1:] - } - - result := &AuthenticationResult{ - Authenticated: true, - Subject: longSubject, - Claims: map[string]interface{}{}, - Expiration: time.Now().Add(time.Hour), - } - - assert.Equal(t, 10000, len(result.Subject)) - assert.Equal(t, longSubject, result.Subject) - }) - - t.Run("complex claims structure", func(t *testing.T) { - complexClaims := map[string]interface{}{ - "string_claim": "value", - "int_claim": 42, - "float_claim": 3.14, - "bool_claim": true, - "array_claim": []interface{}{"a", "b", "c"}, - "nested_claim": map[string]interface{}{ - "inner_string": "inner_value", - "inner_int": 123, - }, - } - - result := &AuthenticationResult{ - Authenticated: true, - Subject: "user123", - Claims: complexClaims, - Expiration: time.Now().Add(time.Hour), - } - - // Verify all claim types are preserved - assert.Equal(t, "value", result.Claims["string_claim"]) - assert.Equal(t, 42, result.Claims["int_claim"]) - assert.Equal(t, 3.14, result.Claims["float_claim"]) - assert.Equal(t, true, result.Claims["bool_claim"]) - assert.Equal(t, []interface{}{"a", "b", "c"}, result.Claims["array_claim"]) - - nestedClaim := result.Claims["nested_claim"].(map[string]interface{}) - assert.Equal(t, "inner_value", nestedClaim["inner_string"]) - assert.Equal(t, 123, nestedClaim["inner_int"]) - }) +// Note: JWT and SPIFFE authenticator tests are in their respective _test.go files +// (jwt_test.go and spiffe_test.go) where they can properly mock external dependencies +// without making real network calls or requiring actual SPIFFE infrastructure. +// +// This file focuses on testing the factory pattern and basic functionality +// that doesn't require external dependencies. + +func TestAuthenticatorFactory_FactoryPattern(t *testing.T) { + // Test that the factory correctly registers authenticator constructors + ctx := context.Background() + logger := zap.NewNop() + + factory, err := newAuthenticatorFactoryProvider(ctx, logger) + require.NoError(t, err) + + // Cast to concrete type to access internal state for testing + concreteFactory, ok := factory.(*authenticatorFactory) + require.True(t, ok, "Factory should be of type *authenticatorFactory") + + // Verify that JWT and SPIFFE providers are registered + assert.Contains(t, concreteFactory.providers, "jwt", "JWT provider should be registered") + assert.Contains(t, concreteFactory.providers, "spiffe", "SPIFFE provider should be registered") + assert.Len(t, concreteFactory.providers, 2, "Should have exactly 2 providers registered") } diff --git a/auth/fx.go b/auth/fx.go new file mode 100644 index 0000000..b825a25 --- /dev/null +++ b/auth/fx.go @@ -0,0 +1,7 @@ +package auth + +import "go.uber.org/fx" + +var Module = fx.Provide( + newAuthenticatorFactoryProvider, +) diff --git a/auth/manager.go b/auth/manager.go deleted file mode 100644 index 3e9b482..0000000 --- a/auth/manager.go +++ /dev/null @@ -1,71 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "sync" -) - -type AuthManager struct { - authenticators map[string]Authenticator - mu sync.RWMutex -} - -func NewAuthManager() *AuthManager { - return &AuthManager{ - authenticators: make(map[string]Authenticator), - } -} - -func (am *AuthManager) RegisterAuthenticator(auth Authenticator) error { - am.mu.Lock() - defer am.mu.Unlock() - - typ := auth.Type() - if _, exists := am.authenticators[typ]; exists { - return fmt.Errorf("authenticator with type %s already registered", typ) - } - - am.authenticators[typ] = auth - return nil -} - -func (am *AuthManager) GetAuthenticator(name string) (Authenticator, error) { - am.mu.RLock() - defer am.mu.RUnlock() - - auth, exists := am.authenticators[name] - if !exists { - return nil, fmt.Errorf("authenticator with name %s not found", name) - } - - return auth, nil -} - -func (am *AuthManager) Authenticate(ctx context.Context, name string, credentials interface{}) (*AuthenticationResult, error) { - auth, err := am.GetAuthenticator(name) - if err != nil { - return nil, err - } - - return auth.Authenticate(ctx, credentials) -} - -func (am *AuthManager) Close() error { - am.mu.Lock() - defer am.mu.Unlock() - - var errs []error - for name, auth := range am.authenticators { - if err := auth.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close authenticator %s: %w", name, err)) - } - } - - if len(errs) > 0 { - return errors.Join(errs...) - } - - return nil -} diff --git a/auth/manager_test.go b/auth/manager_test.go deleted file mode 100644 index 3873d8c..0000000 --- a/auth/manager_test.go +++ /dev/null @@ -1,413 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -// MockAuthenticator is a mock implementation of the Authenticator interface -type MockAuthenticator struct { - mock.Mock - authType string -} - -func NewMockAuthenticator(authType string) *MockAuthenticator { - return &MockAuthenticator{authType: authType} -} - -func (m *MockAuthenticator) Type() string { - if m.authType != "" { - return m.authType - } - args := m.Called() - return args.String(0) -} - -func (m *MockAuthenticator) Init(ctx context.Context, config map[string]interface{}) error { - args := m.Called(ctx, config) - return args.Error(0) -} - -func (m *MockAuthenticator) Authenticate(ctx context.Context, credentials interface{}) (*AuthenticationResult, error) { - args := m.Called(ctx, credentials) - return args.Get(0).(*AuthenticationResult), args.Error(1) -} - -func (m *MockAuthenticator) Close() error { - args := m.Called() - return args.Error(0) -} - -func TestNewAuthManager(t *testing.T) { - manager := NewAuthManager() - - assert.NotNil(t, manager) - assert.NotNil(t, manager.authenticators) - assert.Equal(t, 0, len(manager.authenticators)) -} - -func TestAuthManager_RegisterAuthenticator(t *testing.T) { - tests := []struct { - name string - authenticators []string - expectError bool - errorContains string - }{ - { - name: "register single authenticator", - authenticators: []string{"jwt"}, - expectError: false, - }, - { - name: "register multiple authenticators", - authenticators: []string{"jwt", "oauth", "spiffe"}, - expectError: false, - }, - { - name: "register duplicate authenticator", - authenticators: []string{"jwt", "jwt"}, - expectError: true, - errorContains: "authenticator with type jwt already registered", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - manager := NewAuthManager() - var err error - - for _, authType := range tt.authenticators { - mockAuth := NewMockAuthenticator(authType) - err = manager.RegisterAuthenticator(mockAuth) - - if tt.expectError && err != nil { - break - } - } - - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - assert.Equal(t, len(tt.authenticators), len(manager.authenticators)) - } - }) - } -} - -func TestAuthManager_GetAuthenticator(t *testing.T) { - manager := NewAuthManager() - - // Register test authenticators - jwtAuth := NewMockAuthenticator("jwt") - oauthAuth := NewMockAuthenticator("oauth") - - err := manager.RegisterAuthenticator(jwtAuth) - require.NoError(t, err) - err = manager.RegisterAuthenticator(oauthAuth) - require.NoError(t, err) - - tests := []struct { - name string - authName string - expectError bool - errorContains string - }{ - { - name: "get existing jwt authenticator", - authName: "jwt", - expectError: false, - }, - { - name: "get existing oauth authenticator", - authName: "oauth", - expectError: false, - }, - { - name: "get non-existing authenticator", - authName: "nonexistent", - expectError: true, - errorContains: "authenticator with name nonexistent not found", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - auth, err := manager.GetAuthenticator(tt.authName) - - if tt.expectError { - assert.Error(t, err) - assert.Nil(t, auth) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - assert.NotNil(t, auth) - assert.Equal(t, tt.authName, auth.Type()) - } - }) - } -} - -func TestAuthManager_Authenticate(t *testing.T) { - manager := NewAuthManager() - ctx := context.Background() - - // Setup mock authenticator - mockAuth := NewMockAuthenticator("jwt") - expectedResult := &AuthenticationResult{ - Authenticated: true, - Subject: "test-user", - Claims: map[string]interface{}{"role": "admin"}, - Expiration: time.Now().Add(time.Hour), - } - - mockAuth.On("Authenticate", ctx, "valid-token").Return(expectedResult, nil) - mockAuth.On("Authenticate", ctx, "invalid-token").Return((*AuthenticationResult)(nil), errors.New("invalid token")) - - err := manager.RegisterAuthenticator(mockAuth) - require.NoError(t, err) - - tests := []struct { - name string - authName string - credentials interface{} - expectError bool - errorContains string - expectedAuth bool - }{ - { - name: "successful authentication", - authName: "jwt", - credentials: "valid-token", - expectError: false, - expectedAuth: true, - }, - { - name: "failed authentication", - authName: "jwt", - credentials: "invalid-token", - expectError: true, - errorContains: "invalid token", - }, - { - name: "non-existing authenticator", - authName: "nonexistent", - credentials: "any-token", - expectError: true, - errorContains: "authenticator with name nonexistent not found", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := manager.Authenticate(ctx, tt.authName, tt.credentials) - - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, tt.expectedAuth, result.Authenticated) - } - }) - } - - mockAuth.AssertExpectations(t) -} - -func TestAuthManager_Close(t *testing.T) { - tests := []struct { - name string - setupMocks func() []*MockAuthenticator - expectError bool - errorContains string - }{ - { - name: "close all authenticators successfully", - setupMocks: func() []*MockAuthenticator { - auth1 := NewMockAuthenticator("jwt") - auth2 := NewMockAuthenticator("oauth") - auth1.On("Close").Return(nil) - auth2.On("Close").Return(nil) - return []*MockAuthenticator{auth1, auth2} - }, - expectError: false, - }, - { - name: "close with one authenticator error", - setupMocks: func() []*MockAuthenticator { - auth1 := NewMockAuthenticator("jwt") - auth2 := NewMockAuthenticator("oauth") - auth1.On("Close").Return(errors.New("close error")) - auth2.On("Close").Return(nil) - return []*MockAuthenticator{auth1, auth2} - }, - expectError: true, - errorContains: "failed to close authenticator jwt", - }, - { - name: "close with multiple authenticator errors", - setupMocks: func() []*MockAuthenticator { - auth1 := NewMockAuthenticator("jwt") - auth2 := NewMockAuthenticator("oauth") - auth1.On("Close").Return(errors.New("jwt close error")) - auth2.On("Close").Return(errors.New("oauth close error")) - return []*MockAuthenticator{auth1, auth2} - }, - expectError: true, - errorContains: "failed to close authenticator", - }, - { - name: "close empty manager", - setupMocks: func() []*MockAuthenticator { - return []*MockAuthenticator{} - }, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - manager := NewAuthManager() - mocks := tt.setupMocks() - - // Register all mock authenticators - for _, mockAuth := range mocks { - err := manager.RegisterAuthenticator(mockAuth) - require.NoError(t, err) - } - - err := manager.Close() - - if tt.expectError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - - // Verify all mocks - for _, mockAuth := range mocks { - mockAuth.AssertExpectations(t) - } - }) - } -} - -func TestAuthManager_ConcurrentAccess(t *testing.T) { - manager := NewAuthManager() - ctx := context.Background() - - // Setup authenticators - numAuthenticators := 10 - var wg sync.WaitGroup - - // Concurrent registration - wg.Add(numAuthenticators) - for i := 0; i < numAuthenticators; i++ { - go func(id int) { - defer wg.Done() - authType := fmt.Sprintf("auth-%d", id) - mockAuth := NewMockAuthenticator(authType) - mockAuth.On("Authenticate", mock.Anything, mock.Anything).Return(&AuthenticationResult{ - Authenticated: true, - Subject: fmt.Sprintf("user-%d", id), - }, nil) - - err := manager.RegisterAuthenticator(mockAuth) - assert.NoError(t, err) - }(i) - } - wg.Wait() - - // Verify all authenticators were registered - assert.Equal(t, numAuthenticators, len(manager.authenticators)) - - // Concurrent authentication - numRequests := 50 - wg.Add(numRequests) - - for i := 0; i < numRequests; i++ { - go func(id int) { - defer wg.Done() - authType := fmt.Sprintf("auth-%d", id%numAuthenticators) - - result, err := manager.Authenticate(ctx, authType, "test-creds") - assert.NoError(t, err) - assert.True(t, result.Authenticated) - }(i) - } - wg.Wait() - - // Concurrent get operations - wg.Add(numRequests) - for i := 0; i < numRequests; i++ { - go func(id int) { - defer wg.Done() - authType := fmt.Sprintf("auth-%d", id%numAuthenticators) - - auth, err := manager.GetAuthenticator(authType) - assert.NoError(t, err) - assert.NotNil(t, auth) - assert.Equal(t, authType, auth.Type()) - }(i) - } - wg.Wait() -} - -func TestAuthManager_ThreadSafety(t *testing.T) { - manager := NewAuthManager() - ctx := context.Background() - - // Test concurrent read/write operations - var wg sync.WaitGroup - numOperations := 100 - - // Concurrent registration and authentication - wg.Add(numOperations * 2) - - for i := 0; i < numOperations; i++ { - // Registration goroutine - go func(id int) { - defer wg.Done() - authType := fmt.Sprintf("concurrent-auth-%d", id) - mockAuth := NewMockAuthenticator(authType) - mockAuth.On("Authenticate", mock.Anything, mock.Anything).Return(&AuthenticationResult{ - Authenticated: true, - Subject: fmt.Sprintf("user-%d", id), - }, nil) - - manager.RegisterAuthenticator(mockAuth) - }(i) - - // Authentication goroutine (may fail if authenticator not yet registered) - go func(id int) { - defer wg.Done() - authType := fmt.Sprintf("concurrent-auth-%d", id) - manager.Authenticate(ctx, authType, "test-creds") - }(i) - } - - wg.Wait() - - // Verify no race conditions occurred (test should not panic) - assert.True(t, len(manager.authenticators) <= numOperations) -} diff --git a/cmd/main.go b/cmd/main.go index 968a121..9b0112e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,176 +3,116 @@ package main import ( "context" "fmt" - "log" - "net" - "net/http" + "github.com/temporal-sa/temporal-cloud-proxy/auth" + "github.com/temporal-sa/temporal-cloud-proxy/codec" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "github.com/temporal-sa/temporal-cloud-proxy/metrics" + "github.com/temporal-sa/temporal-cloud-proxy/proxy" + "github.com/temporal-sa/temporal-cloud-proxy/transport" + "go.uber.org/fx" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "os" "os/signal" - "strconv" "syscall" - "temporal-sa/temporal-cloud-proxy/auth" - "temporal-sa/temporal-cloud-proxy/crypto" - "temporal-sa/temporal-cloud-proxy/metrics" - "temporal-sa/temporal-cloud-proxy/proxy" - "temporal-sa/temporal-cloud-proxy/utils" - "time" - - "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/urfave/cli/v2" - "go.opentelemetry.io/otel/attribute" - "go.temporal.io/api/workflowservice/v1" - "go.temporal.io/sdk/client" - "google.golang.org/grpc" ) var configFilePath string -func main() { +func run(args []string) error { + app := buildCLIOptions() + return app.Run(args) +} + +func buildCLIOptions() *cli.App { app := &cli.App{ Name: "tclp", Usage: "Temporal Cloud Proxy", Flags: []cli.Flag{ &cli.StringFlag{ - Name: "config", - Usage: "config file", + Name: config.ConfigPathFlag, + Usage: "Path to yaml config file. Default is ./config.yaml", Aliases: []string{"c"}, - Value: "config.yaml", + Value: config.DefaultConfigPath, Destination: &configFilePath, }, + &cli.StringFlag{ + Name: config.LogLevelFlag, + Usage: "Set log level (debug, info, warn, error). Default level is info", + Required: false, + }, }, - Action: func(*cli.Context) error { - configManager, err := utils.NewConfigManager(configFilePath) - if err != nil { - return err - } - defer configManager.Close() - - cfg := configManager.GetConfig() - - proxyConns := proxy.NewConn() - defer proxyConns.CloseAll() - - if err := configureProxy(proxyConns, cfg); err != nil { - return err - } - - workflowClient := workflowservice.NewWorkflowServiceClient(proxyConns) - - handler, err := client.NewWorkflowServiceProxyServer( - client.WorkflowServiceProxyOptions{Client: workflowClient}, - ) - if err != nil { - return err - } - - grpcServer := grpc.NewServer() - workflowservice.RegisterWorkflowServiceServer(grpcServer, handler) - - // Initialize metrics - metrics.InitPrometheus() - metricsServer := &http.Server{Addr: ":" + strconv.Itoa(cfg.Metrics.Port)} - http.Handle(metrics.DefaultPrometheusPath, promhttp.Handler()) - go func() { - fmt.Printf("Metrics is exposed at %s:%d%s\n", cfg.Server.Host, cfg.Metrics.Port, metrics.DefaultPrometheusPath) - if err := metricsServer.ListenAndServe(); err != http.ErrServerClosed { - log.Printf("metrics server error: %v", err) - } - }() - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - fmt.Println("\nShutting down gracefully...") - grpcServer.GracefulStop() - os.Exit(0) - }() - - lis, err := net.Listen("tcp", cfg.Server.Host+":"+strconv.Itoa(cfg.Server.Port)) - if err != nil { - return err - } - - fmt.Printf("Proxy is listening on %s:%d\n", cfg.Server.Host, cfg.Server.Port) - - err = grpcServer.Serve(lis) - - return err - }, + Action: startProxy, } - if err := app.Run(os.Args); err != nil { - log.Fatalln(err) + return app +} + +func startProxy(c *cli.Context) error { + logLevel := c.String(config.LogLevelFlag) + + var zapLevel zapcore.Level + if err := zapLevel.UnmarshalText([]byte(logLevel)); err != nil { + return fmt.Errorf("invalid log level %q: %w", logLevel, err) } + + loggerConfig := zap.NewProductionConfig() + loggerConfig.Level.SetLevel(zapLevel) + logger, err := loggerConfig.Build() + if err != nil { + return fmt.Errorf("failed to initialise logger: %w", err) + } + + app := fx.New( + fx.Provide( + func() *zap.Logger { + return logger + }, + func() *cli.Context { return c }, + func() context.Context { return c.Context }, + ), + + auth.Module, + codec.Module, + config.Module, + metrics.Module, + proxy.Module, + transport.Module, + + fx.Invoke( + func(metrics.MetricsProvider) {}, + func(transport.TransportProvider) {}, + ), + ) + + if err := app.Start(context.Background()); err != nil { + return err + } + + <-interruptCh() + + return app.Stop(context.Background()) } -func configureProxy(proxyConns *proxy.Conn, cfg *utils.Config) error { - ctx := context.TODO() - - for _, w := range cfg.Workloads { - var authManager *auth.AuthManager - var authType string - - if w.Authentication != nil { - authManager = auth.NewAuthManager() - authType = w.Authentication.Type - var authProvider auth.Authenticator - - switch authType { - case "spiffe": - authProvider = &auth.SpiffeAuthenticator{} - case "jwt": - authProvider = &auth.JwtAuthenticator{} - default: - return fmt.Errorf("unsupported authentication type: %s", authType) - } - - if err := authProvider.Init(ctx, w.Authentication.Config); err != nil { - return fmt.Errorf("failed to initialize %s authenticator: %w", authType, err) - } - - if err := authManager.RegisterAuthenticator(authProvider); err != nil { - return err - } - } - - metricsHandler := metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{ - // Todo: do we need these many attributes? - InitialAttributes: attribute.NewSet( - attribute.String("workload_id", w.WorkloadId), - attribute.String("namespace", w.TemporalCloud.Namespace), - attribute.String("host_port", w.TemporalCloud.HostPort), - attribute.String("auth_type", authType), - attribute.String("encryption_key", w.EncryptionKey), - ), - }) - - // Parse global caching config - var cachingConfig *crypto.CachingConfig - if cfg.Encryption.Caching.MaxCache > 0 || cfg.Encryption.Caching.MaxAge != "" || cfg.Encryption.Caching.MaxUsage > 0 { - cachingConfig = &crypto.CachingConfig{ - MaxCache: cfg.Encryption.Caching.MaxCache, - MaxMessagesUsed: cfg.Encryption.Caching.MaxUsage, - } - if cfg.Encryption.Caching.MaxAge != "" { - if duration, err := time.ParseDuration(cfg.Encryption.Caching.MaxAge); err == nil { - cachingConfig.MaxAge = duration - } - } - } - - err := proxyConns.AddConn(proxy.AddConnInput{ - Workload: &w, - AuthManager: authManager, - AuthType: authType, - MetricsHandler: metricsHandler, - CryptoCachingConfig: cachingConfig, - }) - - if err != nil { - return err - } +func main() { + if err := run(os.Args); err != nil { + panic(err) } +} + +func interruptCh() <-chan interface{} { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + ret := make(chan interface{}, 1) + go func() { + s := <-c + ret <- s + close(ret) + signal.Stop(c) + }() - return nil + return ret } diff --git a/codec/codec.go b/codec/codec.go new file mode 100644 index 0000000..bfc80c6 --- /dev/null +++ b/codec/codec.go @@ -0,0 +1,149 @@ +package codec + +import ( + gcpKms "cloud.google.com/go/kms/apiv1" + "context" + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + awsKms "github.com/aws/aws-sdk-go/service/kms" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "github.com/temporal-sa/temporal-cloud-proxy/crypto" + "github.com/temporal-sa/temporal-cloud-proxy/metrics" + "go.opentelemetry.io/otel/attribute" + "go.temporal.io/sdk/converter" + "os" + "time" +) + +// +// This could be extended to include codecs of other types (e.g. compression), but +// is currently focused on encryption specifically. +// + +type ( + EncryptionCodecFactory interface { + NewEncryptionCodec(args EncryptionCodecOptions) (converter.PayloadCodec, error) + } + + EncryptionCodec interface { + converter.PayloadCodec + } + + EncryptionCodecOptions struct { + LocalEncryptionConfig config.EncryptionConfig + CodecContext map[string]string + MetricsHandler *metrics.MetricsHandler + } + + EncryptionCodecConstructor func(args EncryptionCodecOptions) (converter.PayloadCodec, error) + + encryptionCodecFactory struct { + providers map[string]EncryptionCodecConstructor + cachingConfig *crypto.CachingConfig + } +) + +func newCodecFactoryProvider(configProvider config.ConfigProvider) (EncryptionCodecFactory, error) { + var cachingConfig *crypto.CachingConfig + + providerCacheCfg := configProvider.GetProxyConfig().Encryption.Caching + if providerCacheCfg.MaxCache > 0 || providerCacheCfg.MaxAge != "" || providerCacheCfg.MaxUsage > 0 { + cachingConfig = &crypto.CachingConfig{ + MaxCache: providerCacheCfg.MaxCache, + MaxMessagesUsed: providerCacheCfg.MaxUsage, + } + if providerCacheCfg.MaxAge != "" { + if duration, err := time.ParseDuration(providerCacheCfg.MaxAge); err == nil { + cachingConfig.MaxAge = duration + } + } + } + + cf := &encryptionCodecFactory{ + providers: make(map[string]EncryptionCodecConstructor), + cachingConfig: cachingConfig, + } + + cf.providers["aws-kms"] = func(args EncryptionCodecOptions) (converter.PayloadCodec, error) { + rawKeyId, ok := args.LocalEncryptionConfig.Config["key-id"] + if !ok { + return nil, fmt.Errorf("key not found in config") + } + keyId, ok := rawKeyId.(string) + if !ok { + return nil, fmt.Errorf("key is not a string") + } + + region := os.Getenv(config.AwsRegionEnvVar) + if region == "" { + region = config.DefaultAwsRegion + } + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(region), + })) + kmsClient := awsKms.New(sess) + + awsMaterialsManager := crypto.NewAWSKMSProvider(kmsClient, crypto.AWSKMSOptions{ + KeyID: keyId, + KeySpec: "AES_256", + }) + + args.MetricsHandler.AddAttributes(attribute.String("encryption_key", keyId)) + + return NewEncryptionCodecWithCaching( + awsMaterialsManager, + args.CodecContext, + keyId, + args.MetricsHandler, + cf.cachingConfig, + ), nil + } + + cf.providers["gcp-kms"] = func(args EncryptionCodecOptions) (converter.PayloadCodec, error) { + rawKeyName, ok := args.LocalEncryptionConfig.Config["key-name"] + if !ok { + return nil, fmt.Errorf("key not found in config") + } + keyName, ok := rawKeyName.(string) + if !ok { + return nil, fmt.Errorf("key is not a string") + } + + region := os.Getenv(config.GcpRegionEnvVar) + if region == "" { + region = config.DefaultGcpRegion + } + + kmsClient, err := gcpKms.NewKeyManagementClient(context.TODO()) + if err != nil { + return nil, err + } + + gcpMaterialsManager := crypto.NewGCPKMSProvider(kmsClient, crypto.GCPKMSOptions{ + KeyName: keyName, + Algorithm: "AES_256", + }) + + args.MetricsHandler.AddAttributes(attribute.String("encryption_key", keyName)) + + return NewEncryptionCodecWithCaching( + gcpMaterialsManager, + args.CodecContext, + keyName, + args.MetricsHandler, + cf.cachingConfig, + ), nil + } + + return cf, nil +} + +func (e *encryptionCodecFactory) NewEncryptionCodec(args EncryptionCodecOptions) (converter.PayloadCodec, error) { + encryptionCodec, ok := e.providers[args.LocalEncryptionConfig.Type] + if !ok { + return nil, fmt.Errorf("unsupported encryption type %s", args.LocalEncryptionConfig.Type) + } + + return encryptionCodec(args) +} diff --git a/codec/encryption_codec.go b/codec/encryption_codec.go index b21ec28..7ed305e 100644 --- a/codec/encryption_codec.go +++ b/codec/encryption_codec.go @@ -5,10 +5,8 @@ import ( "fmt" "time" - "temporal-sa/temporal-cloud-proxy/crypto" - "temporal-sa/temporal-cloud-proxy/metrics" - - "github.com/aws/aws-sdk-go/service/kms" + "github.com/temporal-sa/temporal-cloud-proxy/crypto" + "github.com/temporal-sa/temporal-cloud-proxy/metrics" commonpb "go.temporal.io/api/common/v1" "go.temporal.io/sdk/client" @@ -39,7 +37,7 @@ type Codec struct { // NewEncryptionCodecWithCaching creates a new encryption codec with configurable caching. func NewEncryptionCodecWithCaching( - kmsClient *kms.KMS, + kmsProvider crypto.MaterialsManager, codecContext map[string]string, encryptionKeyID string, metricsHandler client.MetricsHandler, @@ -54,15 +52,9 @@ func NewEncryptionCodecWithCaching( } } - // Create AWS KMS provider - awsProvider := crypto.NewAWSKMSProvider(kmsClient, crypto.KMSOptions{ - KeyID: encryptionKeyID, - KeySpec: "AES_256", - }) - // Create caching materials manager cachingMM, _ := crypto.NewCachingMaterialsManager( - awsProvider, + kmsProvider, *cachingConfig, metricsHandler, ) diff --git a/codec/fx.go b/codec/fx.go new file mode 100644 index 0000000..292ff95 --- /dev/null +++ b/codec/fx.go @@ -0,0 +1,7 @@ +package codec + +import "go.uber.org/fx" + +var Module = fx.Provide( + newCodecFactoryProvider, +) diff --git a/config.yaml.sample b/config.yaml.sample index 092a4aa..4cea565 100644 --- a/config.yaml.sample +++ b/config.yaml.sample @@ -22,36 +22,16 @@ workloads: cert_file: "/path/to/./tls.crt" key_file: "/path/to/./tls.key" api_key: # only set either value or env, not both - value: "" - env: - encryption_key: "" -# authentication: # spiffe authentication example -# type: "spiffe" + value: "" + env: +# encryption: # aws kms encryption example +# type: "aws-kms" # config: -# trust_domain: "spiffe://example.org/" -# endpoint: "unix:///tmp/spire-agent/public/api.sock" -# audiences: -# - "temporal_cloud_proxy" -# authentication: # jwt authentication example -# type: "jwt" +# key-id: "" +# encryption: # gcp kms encryption example +# type: "gcp-kms" # config: -# jwks-url: "http://localhost:8200/v1/identity/oidc/.well-known/keys" -# audiences: -# - "temporal_cloud_proxy" - - - workload_id: "" - temporal_cloud: - namespace: "." - host_port: "..tmprl.cloud:7233" # endpoint when using mTLS - # host_port: "..api.temporal.io:7233" # endpoint when using API keys - authentication: # only set either tls or api_key, not both - tls: - cert_file: "/path/to/./tls.crt" - key_file: "/path/to/./tls.key" - api_key: # only set either value or env, not both - value: "" - env: - encryption_key: "" +# key-name: "" # authentication: # spiffe authentication example # type: "spiffe" # config: @@ -64,4 +44,4 @@ workloads: # config: # jwks-url: "http://localhost:8200/v1/identity/oidc/.well-known/keys" # audiences: -# - "temporal_cloud_proxy" +# - "temporal_cloud_proxy" \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..c0b5346 --- /dev/null +++ b/config/config.go @@ -0,0 +1,208 @@ +package config + +import ( + "errors" + "fmt" + "github.com/urfave/cli/v2" + "gopkg.in/yaml.v3" + "os" +) + +const ( + ConfigPathFlag = "config" + DefaultConfigPath = "config.yaml" + LogLevelFlag = "level" + + AwsRegionEnvVar = "AWS_REGION" + DefaultAwsRegion = "us-west-2" + + GcpRegionEnvVar = "GCP_REGION" + DefaultGcpRegion = "us-central1" +) + +type ( + ConfigProvider interface { + GetProxyConfig() ProxyConfig + } + + ProxyConfig struct { + Server ServerConfig `yaml:"server"` + Metrics MetricsConfig `yaml:"metrics"` + Encryption GlobalEncryptionConfig `yaml:"encryption"` + Workloads []WorkloadConfig `yaml:"workloads"` + } + + ServerConfig struct { + Port int `yaml:"port"` + Host string `yaml:"host"` + } + + MetricsConfig struct { + Port int `yaml:"port"` + } + + GlobalEncryptionConfig struct { + Caching CachingConfig `yaml:"caching"` + } + + CachingConfig struct { + MaxCache int `yaml:"max_cache,omitempty"` + MaxAge string `yaml:"max_age,omitempty"` + MaxUsage int `yaml:"max_usage,omitempty"` + } + + WorkloadConfig struct { + WorkloadId string `yaml:"workload_id"` + TemporalCloud TemporalCloudConfig `yaml:"temporal_cloud"` + Encryption *EncryptionConfig `yaml:"encryption,omitempty"` + Authentication *AuthConfig `yaml:"authentication,omitempty"` + } + + TemporalCloudConfig struct { + Namespace string `yaml:"namespace"` + HostPort string `yaml:"host_port"` + Authentication TemporalAuthConfig `yaml:"authentication"` + } + + TemporalAuthConfig struct { + TLS *TLSConfig `yaml:"tls,omitempty"` + ApiKey *TemporalApiKeyConfig `yaml:"api_key,omitempty"` + } + + TemporalApiKeyConfig struct { + Value string `yaml:"value,omitempty"` + EnvVar string `yaml:"env,omitempty"` + } + + TLSConfig struct { + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` + } + + AuthConfig struct { + Type string `yaml:"type"` + Config map[string]interface{} `yaml:"config"` + } + + EncryptionConfig struct { + Type string `yaml:"type"` + Config map[string]interface{} `yaml:"config"` + } + + cliConfigProvider struct { + ctx *cli.Context + proxyConfig ProxyConfig + } +) + +func newConfigProvider(ctx *cli.Context) (ConfigProvider, error) { + proxyConfig, err := LoadConfig(ctx.String(ConfigPathFlag)) + if err != nil { + return nil, err + } + + return &cliConfigProvider{ + ctx: ctx, + proxyConfig: proxyConfig, + }, nil +} + +func (c *cliConfigProvider) GetProxyConfig() ProxyConfig { + return c.proxyConfig +} + +func LoadConfig(configFilePath string) (ProxyConfig, error) { + var config ProxyConfig + + configFile, err := os.ReadFile(configFilePath) + if err != nil { + return config, fmt.Errorf("failed to read config file: %w", err) + } + + if err = yaml.Unmarshal(configFile, &config); err != nil { + return config, fmt.Errorf("failed to unmarshal config file: %w", err) + } + + if err = config.Validate(); err != nil { + return config, fmt.Errorf("failed to validate config: %w", err) + } + + return config, nil +} + +func (p *ProxyConfig) Validate() error { + var errs []error + + errs = append(errs, p.Server.Validate()...) + errs = append(errs, p.Metrics.Validate()...) + errs = append(errs, p.Encryption.Validate()...) + + workloadIds := make(map[string]bool) + for _, workload := range p.Workloads { + if _, exists := workloadIds[workload.WorkloadId]; exists { + errs = append(errs, fmt.Errorf("workload already exists: %s", workload.WorkloadId)) + } + errs = append(errs, workload.Validate()...) + workloadIds[workload.WorkloadId] = true + } + + return errors.Join(errs...) +} + +func (s *ServerConfig) Validate() []error { + var errs []error + + if s.Port <= 0 || s.Port > 65535 { + errs = append(errs, fmt.Errorf("invalid server port: %d", s.Port)) + } + + return errs +} + +func (m *MetricsConfig) Validate() []error { + var errs []error + + if m.Port <= 0 || m.Port > 65535 { + errs = append(errs, fmt.Errorf("invalid metrics server port: %d", m.Port)) + } + + return errs +} + +func (g *GlobalEncryptionConfig) Validate() []error { + var errs []error + + if g.Caching.MaxCache < 0 { + errs = append(errs, fmt.Errorf("encryption max_cache must be >= 0: %d", g.Caching.MaxCache)) + } + + if g.Caching.MaxUsage < 0 { + errs = append(errs, fmt.Errorf("encryption max_usage must be >= 0: %d", g.Caching.MaxUsage)) + } + + return errs +} + +func (w *WorkloadConfig) Validate() []error { + var errs []error + + if w.WorkloadId == "" { + errs = append(errs, fmt.Errorf("workload_id is required")) + } + + if w.TemporalCloud.Namespace == "" { + errs = append(errs, fmt.Errorf("temporal cloud namespace must not be blank: %s", w.WorkloadId)) + } + + if w.TemporalCloud.HostPort == "" { + errs = append(errs, fmt.Errorf("temporal cloud hostport must not be blank: %s", w.WorkloadId)) + } + + if w.TemporalCloud.Authentication.ApiKey != nil && w.TemporalCloud.Authentication.TLS != nil { + errs = append(errs, fmt.Errorf( + "cannot have both api key and mtls authentication configured on a single workload: %s", w.WorkloadId, + )) + } + + return errs +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..dc44a08 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,513 @@ +package config + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadConfig(t *testing.T) { + tests := []struct { + name string + configYAML string + expectError bool + errorMsg string + }{ + { + name: "valid config", + configYAML: ` +server: + port: 7233 + host: "0.0.0.0" +metrics: + port: 9090 +encryption: + caching: + max_cache: 100 + max_age: "10m" + max_usage: 100 +workloads: + - workload_id: "test-workload" + temporal_cloud: + namespace: "test.namespace" + host_port: "test.namespace.tmprl.cloud:7233" + authentication: + api_key: + value: "test-key" +`, + expectError: false, + }, + { + name: "invalid yaml", + configYAML: ` +server: + port: 7233 + host: "0.0.0.0" +invalid_yaml: [ +`, + expectError: true, + errorMsg: "failed to unmarshal config file", + }, + { + name: "validation failure - invalid port", + configYAML: ` +server: + port: 70000 + host: "0.0.0.0" +metrics: + port: 9090 +workloads: [] +`, + expectError: true, + errorMsg: "failed to validate config", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary config file + tmpFile, err := os.CreateTemp("", "config-*.yaml") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + _, err = tmpFile.WriteString(tt.configYAML) + require.NoError(t, err) + tmpFile.Close() + + // Test LoadConfig + config, err := LoadConfig(tmpFile.Name()) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotEmpty(t, config) + } + }) + } +} + +func TestProxyConfig_Validate(t *testing.T) { + tests := []struct { + name string + config ProxyConfig + wantErr bool + errMsg string + }{ + { + name: "valid config", + config: ProxyConfig{ + Server: ServerConfig{ + Port: 7233, + Host: "0.0.0.0", + }, + Metrics: MetricsConfig{ + Port: 9090, + }, + Encryption: GlobalEncryptionConfig{ + Caching: CachingConfig{ + MaxCache: 100, + MaxAge: "10m", + MaxUsage: 100, + }, + }, + Workloads: []WorkloadConfig{ + { + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + ApiKey: &TemporalApiKeyConfig{ + Value: "test-key", + }, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "duplicate workload IDs", + config: ProxyConfig{ + Server: ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: 9090}, + Workloads: []WorkloadConfig{ + { + WorkloadId: "duplicate-id", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test1.namespace", + HostPort: "test1.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + ApiKey: &TemporalApiKeyConfig{Value: "key1"}, + }, + }, + }, + { + WorkloadId: "duplicate-id", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test2.namespace", + HostPort: "test2.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + ApiKey: &TemporalApiKeyConfig{Value: "key2"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "workload already exists: duplicate-id", + }, + { + name: "invalid server port - too high", + config: ProxyConfig{ + Server: ServerConfig{Port: 70000, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: 9090}, + }, + wantErr: true, + errMsg: "invalid server port: 70000", + }, + { + name: "invalid server port - zero", + config: ProxyConfig{ + Server: ServerConfig{Port: 0, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: 9090}, + }, + wantErr: true, + errMsg: "invalid server port: 0", + }, + { + name: "invalid metrics port - negative", + config: ProxyConfig{ + Server: ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: -1}, + }, + wantErr: true, + errMsg: "invalid metrics server port: -1", + }, + { + name: "negative encryption max_cache", + config: ProxyConfig{ + Server: ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: 9090}, + Encryption: GlobalEncryptionConfig{ + Caching: CachingConfig{MaxCache: -1}, + }, + }, + wantErr: true, + errMsg: "encryption max_cache must be >= 0: -1", + }, + { + name: "negative encryption max_usage", + config: ProxyConfig{ + Server: ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: MetricsConfig{Port: 9090}, + Encryption: GlobalEncryptionConfig{ + Caching: CachingConfig{MaxUsage: -5}, + }, + }, + wantErr: true, + errMsg: "encryption max_usage must be >= 0: -5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestWorkloadConfig_Validate(t *testing.T) { + tests := []struct { + name string + workload WorkloadConfig + wantErr bool + errMsg string + }{ + { + name: "valid workload with API key", + workload: WorkloadConfig{ + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + ApiKey: &TemporalApiKeyConfig{ + Value: "test-key", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "valid workload with TLS", + workload: WorkloadConfig{ + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + TLS: &TLSConfig{ + CertFile: "/path/to/cert.pem", + KeyFile: "/path/to/key.pem", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "missing workload_id", + workload: WorkloadConfig{ + WorkloadId: "", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + }, + }, + wantErr: true, + errMsg: "workload_id is required", + }, + { + name: "missing namespace", + workload: WorkloadConfig{ + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "", + HostPort: "test.namespace.tmprl.cloud:7233", + }, + }, + wantErr: true, + errMsg: "temporal cloud namespace must not be blank: test-workload", + }, + { + name: "missing host_port", + workload: WorkloadConfig{ + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "", + }, + }, + wantErr: true, + errMsg: "temporal cloud hostport must not be blank: test-workload", + }, + { + name: "both API key and TLS configured", + workload: WorkloadConfig{ + WorkloadId: "test-workload", + TemporalCloud: TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: TemporalAuthConfig{ + ApiKey: &TemporalApiKeyConfig{ + Value: "test-key", + }, + TLS: &TLSConfig{ + CertFile: "/path/to/cert.pem", + KeyFile: "/path/to/key.pem", + }, + }, + }, + }, + wantErr: true, + errMsg: "cannot have both api key and mtls authentication configured on a single workload: test-workload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.workload.Validate() + + if tt.wantErr { + assert.NotEmpty(t, errs) + if tt.errMsg != "" { + found := false + for _, err := range errs { + if assert.Contains(t, err.Error(), tt.errMsg) { + found = true + break + } + } + assert.True(t, found, "Expected error message not found in: %v", errs) + } + } else { + assert.Empty(t, errs) + } + }) + } +} + +func TestServerConfig_Validate(t *testing.T) { + tests := []struct { + name string + config ServerConfig + wantErr bool + errMsg string + }{ + { + name: "valid port", + config: ServerConfig{Port: 8080, Host: "localhost"}, + wantErr: false, + }, + { + name: "port too low", + config: ServerConfig{Port: 0, Host: "localhost"}, + wantErr: true, + errMsg: "invalid server port: 0", + }, + { + name: "port too high", + config: ServerConfig{Port: 70000, Host: "localhost"}, + wantErr: true, + errMsg: "invalid server port: 70000", + }, + { + name: "negative port", + config: ServerConfig{Port: -1, Host: "localhost"}, + wantErr: true, + errMsg: "invalid server port: -1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Validate() + + if tt.wantErr { + assert.NotEmpty(t, errs) + if tt.errMsg != "" { + assert.Contains(t, errs[0].Error(), tt.errMsg) + } + } else { + assert.Empty(t, errs) + } + }) + } +} + +func TestMetricsConfig_Validate(t *testing.T) { + tests := []struct { + name string + config MetricsConfig + wantErr bool + errMsg string + }{ + { + name: "valid port", + config: MetricsConfig{Port: 9090}, + wantErr: false, + }, + { + name: "port too low", + config: MetricsConfig{Port: 0}, + wantErr: true, + errMsg: "invalid metrics server port: 0", + }, + { + name: "port too high", + config: MetricsConfig{Port: 70000}, + wantErr: true, + errMsg: "invalid metrics server port: 70000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Validate() + + if tt.wantErr { + assert.NotEmpty(t, errs) + if tt.errMsg != "" { + assert.Contains(t, errs[0].Error(), tt.errMsg) + } + } else { + assert.Empty(t, errs) + } + }) + } +} + +func TestGlobalEncryptionConfig_Validate(t *testing.T) { + tests := []struct { + name string + config GlobalEncryptionConfig + wantErr bool + errMsg string + }{ + { + name: "valid config", + config: GlobalEncryptionConfig{ + Caching: CachingConfig{ + MaxCache: 100, + MaxUsage: 50, + }, + }, + wantErr: false, + }, + { + name: "zero values are valid", + config: GlobalEncryptionConfig{ + Caching: CachingConfig{ + MaxCache: 0, + MaxUsage: 0, + }, + }, + wantErr: false, + }, + { + name: "negative max_cache", + config: GlobalEncryptionConfig{ + Caching: CachingConfig{ + MaxCache: -1, + MaxUsage: 50, + }, + }, + wantErr: true, + errMsg: "encryption max_cache must be >= 0: -1", + }, + { + name: "negative max_usage", + config: GlobalEncryptionConfig{ + Caching: CachingConfig{ + MaxCache: 100, + MaxUsage: -5, + }, + }, + wantErr: true, + errMsg: "encryption max_usage must be >= 0: -5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Validate() + + if tt.wantErr { + assert.NotEmpty(t, errs) + if tt.errMsg != "" { + assert.Contains(t, errs[0].Error(), tt.errMsg) + } + } else { + assert.Empty(t, errs) + } + }) + } +} diff --git a/config/fx.go b/config/fx.go new file mode 100644 index 0000000..789410f --- /dev/null +++ b/config/fx.go @@ -0,0 +1,7 @@ +package config + +import "go.uber.org/fx" + +var Module = fx.Provide( + newConfigProvider, +) diff --git a/crypto/aws_kms_provider.go b/crypto/aws_kms_provider.go index 37ef3f3..86ee2df 100644 --- a/crypto/aws_kms_provider.go +++ b/crypto/aws_kms_provider.go @@ -9,8 +9,8 @@ import ( "github.com/aws/aws-sdk-go/service/kms/kmsiface" ) -// KMSOptions contains configuration options for AWSKMSProvider -type KMSOptions struct { +// AWWSKMSOptions contains configuration options for AWSKMSProvider +type AWSKMSOptions struct { // KeyID is the ARN or ID of the KMS key to use KeyID string @@ -26,7 +26,7 @@ type AWSKMSProvider struct { } // NewAWSKMSProvider creates a new KMS-based materials manager -func NewAWSKMSProvider(kmsClient kmsiface.KMSAPI, options KMSOptions) *AWSKMSProvider { +func NewAWSKMSProvider(kmsClient kmsiface.KMSAPI, options AWSKMSOptions) *AWSKMSProvider { // Set default keySpec if not provided keySpec := options.KeySpec if keySpec == "" { diff --git a/crypto/aws_kms_provider_test.go b/crypto/aws_kms_provider_test.go index 6cb0a09..0f0951e 100644 --- a/crypto/aws_kms_provider_test.go +++ b/crypto/aws_kms_provider_test.go @@ -48,17 +48,17 @@ func (m *MockKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, er func TestNewAWSKMSProvider(t *testing.T) { tests := []struct { name string - options KMSOptions + options AWSKMSOptions expected string }{ { name: "Default KeySpec", - options: KMSOptions{KeyID: "test-key-id"}, + options: AWSKMSOptions{KeyID: "test-key-id"}, expected: "AES_256", }, { name: "Custom KeySpec", - options: KMSOptions{KeyID: "test-key-id", KeySpec: "RSA_2048"}, + options: AWSKMSOptions{KeyID: "test-key-id", KeySpec: "RSA_2048"}, expected: "RSA_2048", }, } @@ -123,7 +123,7 @@ func TestAWSKMSProvider_GetMaterial(t *testing.T) { generateDataKeyError: tt.mockError, } - provider := NewAWSKMSProvider(mockKMS, KMSOptions{KeyID: "test-key-id"}) + provider := NewAWSKMSProvider(mockKMS, AWSKMSOptions{KeyID: "test-key-id"}) ctx := context.Background() material, err := provider.GetMaterial(ctx, tt.context) @@ -188,7 +188,7 @@ func TestAWSKMSProvider_DecryptMaterial(t *testing.T) { decryptError: tt.mockError, } - provider := NewAWSKMSProvider(mockKMS, KMSOptions{KeyID: "test-key-id"}) + provider := NewAWSKMSProvider(mockKMS, AWSKMSOptions{KeyID: "test-key-id"}) inputMaterial := &Material{ EncryptedKey: tt.encryptedKey, } @@ -223,7 +223,7 @@ func TestAWSKMSProvider_EncryptionContextHandling(t *testing.T) { }, } - provider := NewAWSKMSProvider(mockKMS, KMSOptions{KeyID: "test-key-id"}) + provider := NewAWSKMSProvider(mockKMS, AWSKMSOptions{KeyID: "test-key-id"}) ctx := context.Background() _, err := provider.GetMaterial(ctx, emptyContext) require.NoError(t, err) diff --git a/crypto/benchmark_test.go b/crypto/benchmark_test.go index fdfe4b1..e96424f 100644 --- a/crypto/benchmark_test.go +++ b/crypto/benchmark_test.go @@ -31,7 +31,7 @@ func setupManagers(b testing.TB) (*CachingMaterialsManager, *CachingMaterialsMan kmsClient := kms.New(sess) // Create the AWS KMS provider - awsProvider := NewAWSKMSProvider(kmsClient, KMSOptions{KeyID: keyID}) + awsProvider := NewAWSKMSProvider(kmsClient, AWSKMSOptions{KeyID: keyID}) // Create the caching materials manager cachingMM, err := NewCachingMaterialsManager( @@ -217,7 +217,7 @@ func TestCachingBehavior(t *testing.T) { kmsClient := kms.New(sess) // Create the AWS KMS provider - awsProvider := NewAWSKMSProvider(kmsClient, KMSOptions{KeyID: keyID}) + awsProvider := NewAWSKMSProvider(kmsClient, AWSKMSOptions{KeyID: keyID}) // Create the caching materials manager with short TTL for testing cachingMM, err := NewCachingMaterialsManager( diff --git a/crypto/caching_materials_manager.go b/crypto/caching_materials_manager.go index 58f778a..3c2cf85 100644 --- a/crypto/caching_materials_manager.go +++ b/crypto/caching_materials_manager.go @@ -4,9 +4,9 @@ import ( "context" "crypto/sha256" "fmt" + "github.com/temporal-sa/temporal-cloud-proxy/metrics" "sort" "sync" - "temporal-sa/temporal-cloud-proxy/metrics" "time" lru "github.com/hashicorp/golang-lru" diff --git a/crypto/usage_count_test.go b/crypto/usage_count_test.go index e5a269d..0b9ee22 100644 --- a/crypto/usage_count_test.go +++ b/crypto/usage_count_test.go @@ -26,7 +26,7 @@ func TestUsageCount(t *testing.T) { kmsClient := kms.New(sess) // Create the AWS KMS provider - awsProvider := NewAWSKMSProvider(kmsClient, KMSOptions{KeyID: keyID}) + awsProvider := NewAWSKMSProvider(kmsClient, AWSKMSOptions{KeyID: keyID}) // Create the caching materials manager with specific usage limit maxUsage := 3 @@ -92,7 +92,7 @@ func TestDecryptionWithDecryptMaterial(t *testing.T) { kmsClient := kms.New(sess) // Create the AWS KMS provider - awsProvider := NewAWSKMSProvider(kmsClient, KMSOptions{KeyID: keyID}) + awsProvider := NewAWSKMSProvider(kmsClient, AWSKMSOptions{KeyID: keyID}) // Create the caching materials manager with specific usage limit maxUsage := 3 // This should not affect decryption diff --git a/go.mod b/go.mod index 76d9d57..7b6c918 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module temporal-sa/temporal-cloud-proxy +module github.com/temporal-sa/temporal-cloud-proxy go 1.24.1 @@ -18,16 +18,32 @@ require ( go.temporal.io/api v1.47.0 go.temporal.io/sdk v1.34.0 go.temporal.io/sdk/contrib/opentelemetry v0.6.0 + go.uber.org/fx v1.24.0 + go.uber.org/zap v1.27.0 google.golang.org/grpc v1.73.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + cloud.google.com/go v0.120.0 // indirect + cloud.google.com/go/auth v0.16.1 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/iam v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v4 v4.0.5 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/zeebo/errs v1.4.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.uber.org/dig v1.19.0 // indirect + go.uber.org/multierr v1.10.0 // indirect golang.org/x/crypto v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 // indirect ) require ( diff --git a/go.sum b/go.sum index 4779104..179ffe4 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,14 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= +cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q= +cloud.google.com/go/auth v0.16.1 h1:XrXauHMd30LhQYVRHLGvJiYeczweKQXZxsTbV9TiguU= +cloud.google.com/go/auth v0.16.1/go.mod h1:1howDHJ5IETh/LwYs3ZxvlkXF48aSqqJUM+5o02dNOI= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/iam v1.5.2 h1:qgFRAGEmd8z6dJ/qyEchAuL9jpswyODjA2lS+w234g8= +cloud.google.com/go/iam v1.5.2/go.mod h1:SE1vg0N81zQqLzQEwxL2WI6yhetBdbNQuTvIKCSkUHE= cloud.google.com/go/kms v1.22.0 h1:dBRIj7+GDeeEvatJeTB19oYZNV0aj6wEqSIT/7gLqtk= cloud.google.com/go/kms v1.22.0/go.mod h1:U7mf8Sva5jpOb4bxYZdtw/9zsbIjrklYwPcvMk34AL8= cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE= @@ -29,6 +39,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -56,8 +68,12 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.14.2 h1:eBLnkZ9635krYIPD+ag1USrOAI0Nr0QYF3+/3GqO0k0= github.com/googleapis/gax-go/v2 v2.14.2/go.mod h1:ON64QhlJkhVtSqp4v1uaK92VyZ2gmvDQsweuyLV+8+w= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= @@ -131,6 +147,10 @@ github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel/exporters/prometheus v0.59.0 h1:HHf+wKS6o5++XZhS98wvILrLVgHxjA/AMjqHKes+uzo= @@ -150,9 +170,19 @@ go.temporal.io/sdk v1.34.0/go.mod h1:iE4U5vFrH3asOhqpBBphpj9zNtw8btp8+MSaf5A0D3w go.temporal.io/sdk/contrib/opentelemetry v0.6.0 h1:rNBArDj5iTUkcMwKocUShoAW59o6HdS7Nq4CTp4ldj8= go.temporal.io/sdk/contrib/opentelemetry v0.6.0/go.mod h1:Lem8VrE2ks8P+FYcRM3UphPoBr+tfM3v/Kaf0qStzSg= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= +go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg= +go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -178,6 +208,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/metrics/constants.go b/metrics/constants.go index 9203bed..f86033a 100644 --- a/metrics/constants.go +++ b/metrics/constants.go @@ -28,4 +28,15 @@ const ( MaterialsManagerDecryptRequests = TemporalProxyPrefix + "materials_manager_decrypt_requests" MaterialsManagerDecryptErrors = TemporalProxyPrefix + "materials_manager_decrypt_errors" MaterialsManagerDecryptSuccess = TemporalProxyPrefix + "materials_manager_decrypt_success" + + // Proxy metrics + ProxyRequestTotal = TemporalProxyPrefix + "requests_total" + ProxyRequestSuccess = TemporalProxyPrefix + "request_success" + ProxyRequestErrors = TemporalProxyPrefix + "request_errors" + ProxyLatency = TemporalProxyPrefix + "latency" + + // Namespace metrics + //ProxyRequestRequests = TemporalProxyPrefix + "requests_total" + //ProxyRequestSuccess = TemporalProxyPrefix + "request_success" + //ProxyRequestErrors = TemporalProxyPrefix + "request_errors" ) diff --git a/metrics/fx.go b/metrics/fx.go new file mode 100644 index 0000000..9ae100f --- /dev/null +++ b/metrics/fx.go @@ -0,0 +1,7 @@ +package metrics + +import "go.uber.org/fx" + +var Module = fx.Provide( + newMetricsProvider, +) diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 0000000..eec1b9b --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,79 @@ +package metrics + +import ( + "context" + "errors" + "fmt" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "go.uber.org/fx" + "go.uber.org/zap" + "net/http" +) + +type ( + MetricsProvider interface { + Start() error + Stop() error + } + + httpPromMetricsProvider struct { + host string + port int + path string + server *http.Server + logger *zap.Logger + } +) + +func newMetricsProvider(lc fx.Lifecycle, configProvider config.ConfigProvider, logger *zap.Logger) MetricsProvider { + provider := &httpPromMetricsProvider{ + host: configProvider.GetProxyConfig().Server.Host, + port: configProvider.GetProxyConfig().Metrics.Port, + path: DefaultPrometheusPath, + logger: logger, + } + + // Initialize metrics + _, err := InitPrometheus() + if err != nil { + logger.Fatal("failed to initialize prometheus provider", zap.Error(err)) + } + + provider.server = &http.Server{Addr: provider.getHostPort()} + http.Handle(provider.path, promhttp.Handler()) + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return provider.Start() + }, + OnStop: func(ctx context.Context) error { + return provider.Stop() + }, + }) + + return provider +} + +func (h *httpPromMetricsProvider) Start() error { + go func() { + h.logger.Info("metrics server started", zap.String("endpoint", h.getHostPortPath())) + if err := h.server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + h.logger.Fatal("metrics server error: %v", zap.Error(err)) + } + }() + + return nil +} + +func (h *httpPromMetricsProvider) Stop() error { + return nil +} + +func (h *httpPromMetricsProvider) getHostPort() string { + return fmt.Sprintf("%s:%d", h.host, h.port) +} + +func (h *httpPromMetricsProvider) getHostPortPath() string { + return fmt.Sprintf("%s:%d%s", h.host, h.port, h.path) +} diff --git a/metrics/metrics_handler.go b/metrics/metrics_handler.go index 8daf3a7..c59304d 100644 --- a/metrics/metrics_handler.go +++ b/metrics/metrics_handler.go @@ -102,6 +102,10 @@ func (m MetricsHandler) GetAttributes() attribute.Set { return m.attributes } +func (m MetricsHandler) AddAttributes(attributes ...attribute.KeyValue) { + m.attributes = attribute.NewSet(append(m.attributes.ToSlice(), attributes...)...) +} + func (m MetricsHandler) WithTags(tags map[string]string) client.MetricsHandler { attributes := m.attributes.ToSlice() for k, v := range tags { diff --git a/proxy/fx.go b/proxy/fx.go new file mode 100644 index 0000000..7614b0d --- /dev/null +++ b/proxy/fx.go @@ -0,0 +1,7 @@ +package proxy + +import "go.uber.org/fx" + +var Module = fx.Provide( + newProxyProvider, +) diff --git a/proxy/proxy.go b/proxy/proxy.go index 0814be5..3824c54 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,114 +5,238 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/temporal-sa/temporal-cloud-proxy/auth" + "github.com/temporal-sa/temporal-cloud-proxy/codec" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "github.com/temporal-sa/temporal-cloud-proxy/metrics" + "os" + "sync" + "time" + + "go.opentelemetry.io/otel/attribute" "go.temporal.io/sdk/converter" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "os" - "sync" - "temporal-sa/temporal-cloud-proxy/auth" - "temporal-sa/temporal-cloud-proxy/codec" - "temporal-sa/temporal-cloud-proxy/crypto" - "temporal-sa/temporal-cloud-proxy/metrics" - "temporal-sa/temporal-cloud-proxy/utils" ) -type Conn struct { - mu sync.RWMutex - namespace map[string]NamespaceConn -} +type ( + ProxyProvider interface { + GetConnectionMux() grpc.ClientConnInterface + Start() error + Stop() error + } -type NamespaceConn struct { - conn *grpc.ClientConn - authManager *auth.AuthManager - authType string -} + proxyServer struct { + grpc.ClientConnInterface + connectionMux map[string]namespaceConnection + mu sync.RWMutex + logger *zap.Logger + metricsHandler metrics.MetricsHandler + } -func NewConn() *Conn { - return &Conn{ - namespace: make(map[string]NamespaceConn), + namespaceConnection struct { + conn *grpc.ClientConn + auth *auth.Authenticator + metricsHandler metrics.MetricsHandler } -} +) -// createKMSClient creates an AWS KMS client -func createKMSClient() *kms.KMS { - // Use the region from parameter or environment variable - region := os.Getenv("AWS_REGION") - if region == "" { - region = "us-west-2" // Default region +func newProxyProvider(configProvider config.ConfigProvider, logger *zap.Logger, + authFactory auth.AuthenticatorFactory, codecFactory codec.EncryptionCodecFactory) (ProxyProvider, error) { + proxy := &proxyServer{ + connectionMux: make(map[string]namespaceConnection), + logger: logger, + metricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), } - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(region), - })) + for _, w := range configProvider.GetProxyConfig().Workloads { + logger.Debug("adding namespace connection", + zap.String("workload-id", w.WorkloadId), + zap.String("namespace", w.TemporalCloud.Namespace), + ) + + nsConn := &namespaceConnection{ + metricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{ + InitialAttributes: attribute.NewSet( + attribute.String("workload_id", w.WorkloadId), + attribute.String("namespace", w.TemporalCloud.Namespace), + ), + }), + } - return kms.New(sess) -} + // configure worker auth + if w.Authentication == nil { + logger.Warn("workload configured without worker authentication", + zap.String("workload-id", w.WorkloadId)) + } + if w.Authentication != nil { + authenticator, err := authFactory.NewAuthenticator(*w.Authentication) + if err != nil { + logger.Error("failed to create authenticator", + zap.String("workload-id", w.WorkloadId), zap.Error(err)) + return nil, fmt.Errorf( + "failed to create authenticator for workload %s, %w", w.WorkloadId, err) + } + nsConn.auth = &authenticator + } -// AddConnInput contains parameters for adding a new connection -type AddConnInput struct { - Workload *utils.WorkloadConfig - AuthManager *auth.AuthManager - AuthType string - MetricsHandler metrics.MetricsHandler - CryptoCachingConfig *crypto.CachingConfig -} + var grpcInterceptors []grpc.UnaryClientInterceptor -// AddConn adds a new connection to the proxy -func (mc *Conn) AddConn(input AddConnInput) error { - fmt.Printf("Adding connection id: %s namespace: %s hostport: %s\n", - input.Workload.WorkloadId, input.Workload.TemporalCloud.Namespace, input.Workload.TemporalCloud.HostPort) + // configure encryption/decryption codec + if w.Encryption == nil { + logger.Warn("workload configured without payload encryption", + zap.String("workload-id", w.WorkloadId)) + } + if w.Encryption != nil { + // configure encryption metrics handler + attributes := []attribute.KeyValue{ + attribute.String("workload_id", w.WorkloadId), + attribute.String("namespace", w.TemporalCloud.Namespace), + attribute.String("host_port", w.TemporalCloud.HostPort), + } + if w.Authentication != nil { + attributes = append(attributes, attribute.String("auth_type", w.Authentication.Type)) + } - mc.mu.RLock() - _, exists := mc.namespace[input.Workload.WorkloadId] - mc.mu.RUnlock() - if exists { - return fmt.Errorf("workload-id %s already exists", input.Workload.WorkloadId) - } + metricsHandler := metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{ + InitialAttributes: attribute.NewSet( + attributes..., + // Note: encryption codec adds additional attributes + ), + }) - if input.Workload.TemporalCloud.Authentication.ApiKey != nil && input.Workload.TemporalCloud.Authentication.TLS != nil { - return fmt.Errorf("%s: cannot have both api key and mtls authentication configured on a single workload", - input.Workload.WorkloadId) + encryptionCodec, err := codecFactory.NewEncryptionCodec(codec.EncryptionCodecOptions{ + LocalEncryptionConfig: *w.Encryption, + MetricsHandler: &metricsHandler, + CodecContext: map[string]string{ + "namespace": w.TemporalCloud.Namespace, + }, + }) + if err != nil { + logger.Error("failed to create encryption codec", + zap.String("workload-id", w.WorkloadId), zap.Error(err)) + return nil, fmt.Errorf( + "failed to create encryption codec for workload %s, %w", w.WorkloadId, err) + } + + if encryptionCodec != nil { + encryptionInterceptor, err := converter.NewPayloadCodecGRPCClientInterceptor( + converter.PayloadCodecGRPCClientInterceptorOptions{ + Codecs: []converter.PayloadCodec{encryptionCodec}, + }, + ) + if err != nil { + logger.Error("failed to create client interceptor", + zap.String("workload-id", w.WorkloadId), zap.Error(err)) + return nil, fmt.Errorf( + "failed to create client interceptor for workload %s, %w", w.WorkloadId, err) + } + if encryptionInterceptor != nil { + grpcInterceptors = append(grpcInterceptors, encryptionInterceptor) + } + } + } + + // set api key or mTLS auth on the namespace connection + tlsConfig, authInterceptor, err := setNamespaceAuth(w, logger) + if err != nil { + logger.Error("failed to set namespace auth", + zap.String("workload-id", w.WorkloadId), zap.Error(err)) + return nil, fmt.Errorf( + "failed to set namespace auth for workload %s, %w", w.WorkloadId, err) + } + if authInterceptor != nil { + grpcInterceptors = append(grpcInterceptors, authInterceptor) + } + + conn, err := grpc.NewClient( + w.TemporalCloud.HostPort, + grpc.WithTransportCredentials(credentials.NewTLS( + tlsConfig, + )), + grpc.WithChainUnaryInterceptor(grpcInterceptors...), + ) + if err != nil { + logger.Error("failed to create grpc client", + zap.String("workload-id", w.WorkloadId), zap.Error(err)) + return nil, fmt.Errorf( + "failed to create grpc client for workload %s, %w", w.WorkloadId, err) + } + + nsConn.conn = conn + + proxy.mu.Lock() + proxy.connectionMux[w.WorkloadId] = *nsConn + proxy.mu.Unlock() } - //Initialize AWS KMS client - kmsClient := createKMSClient() + return proxy, nil +} - codecContext := map[string]string{ - "namespace": input.Workload.TemporalCloud.Namespace, +func (n *namespaceConnection) GetConnection() grpc.ClientConnInterface { + return n.conn +} + +func (n *namespaceConnection) GetAuthenticator() auth.Authenticator { + if n.auth == nil { + return nil } + return *n.auth +} - clientInterceptor, err := converter.NewPayloadCodecGRPCClientInterceptor( - converter.PayloadCodecGRPCClientInterceptorOptions{ - Codecs: []converter.PayloadCodec{codec.NewEncryptionCodecWithCaching( - kmsClient, - codecContext, - input.Workload.EncryptionKey, - input.MetricsHandler, - input.CryptoCachingConfig, - )}, - }, - ) - if err != nil { - return err +func (n *namespaceConnection) Close() error { + var errs []error + + if err := n.conn.Close(); err != nil { + errs = append(errs, err) } + if n.auth != nil { + if err := n.GetAuthenticator().Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func (p *proxyServer) GetConnectionMux() grpc.ClientConnInterface { + return p +} + +func (p *proxyServer) Start() error { + return nil +} - tlsConfig := tls.Config{} +func (p *proxyServer) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() - grpcInterceptors := []grpc.UnaryClientInterceptor{ - clientInterceptor, + var errs []error + + for _, conn := range p.connectionMux { + if err := conn.Close(); err != nil { + errs = append(errs, err) + } } - if apiKeyConfig := input.Workload.TemporalCloud.Authentication.ApiKey; apiKeyConfig != nil { + return errors.Join(errs...) +} + +func setNamespaceAuth(workloadConfig config.WorkloadConfig, logger *zap.Logger) (*tls.Config, grpc.UnaryClientInterceptor, error) { + tlsConfig := &tls.Config{} + var grpcInterceptor grpc.UnaryClientInterceptor + + if apiKeyConfig := workloadConfig.TemporalCloud.Authentication.ApiKey; apiKeyConfig != nil { + // + // Configure API key auth + // if apiKeyConfig.Value != "" && apiKeyConfig.EnvVar != "" { - // TODO proper logging - fmt.Printf("WARN - multiple values provided for api key, using value. workload-id: %s\n", input.Workload.WorkloadId) + logger.Warn("both value and envvar provider for api key; using value", + zap.String("workload-id", workloadConfig.WorkloadId)) } apiKey := "" @@ -123,10 +247,10 @@ func (mc *Conn) AddConn(input AddConnInput) error { } if apiKey == "" { - return fmt.Errorf("%s: no api key provided", input.Workload.WorkloadId) + return nil, nil, fmt.Errorf("no api key provided") } - grpcInterceptors = append(grpcInterceptors, + grpcInterceptor = func(ctx context.Context, method string, req any, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { md, ok := metadata.FromIncomingContext(ctx) @@ -136,111 +260,98 @@ func (mc *Conn) AddConn(input AddConnInput) error { md.Delete("temporal-namespace") ctx = metadata.NewOutgoingContext(ctx, md) - ctx = metadata.AppendToOutgoingContext(ctx, "temporal-namespace", input.Workload.TemporalCloud.Namespace) + ctx = metadata.AppendToOutgoingContext(ctx, "temporal-namespace", workloadConfig.TemporalCloud.Namespace) ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+apiKey) } return invoker(ctx, method, req, reply, cc, opts...) - }) - } else { - cert, err := tls.LoadX509KeyPair(input.Workload.TemporalCloud.Authentication.TLS.CertFile, - input.Workload.TemporalCloud.Authentication.TLS.KeyFile) + } + } else if workloadConfig.TemporalCloud.Authentication.TLS != nil { + // + // Configure mTLS auth + // + cert, err := tls.LoadX509KeyPair(workloadConfig.TemporalCloud.Authentication.TLS.CertFile, + workloadConfig.TemporalCloud.Authentication.TLS.KeyFile) if err != nil { - return err + return nil, nil, err } tlsConfig.Certificates = []tls.Certificate{cert} + } else { + // Passthrough. Useful if the client/worker is setting the API. Note: will not work with + // mTLS configured at the client/worker. } - conn, err := grpc.NewClient( - input.Workload.TemporalCloud.HostPort, - grpc.WithTransportCredentials(credentials.NewTLS( - &tlsConfig, - )), - grpc.WithChainUnaryInterceptor(grpcInterceptors...), - ) - if err != nil { - return err - } - - mc.mu.Lock() - mc.namespace[input.Workload.WorkloadId] = NamespaceConn{ - conn: conn, - authManager: input.AuthManager, - authType: input.AuthType, - } - mc.mu.Unlock() - - return nil -} - -// CloseAll closes all connections -func (mc *Conn) CloseAll() error { - mc.mu.Lock() - defer mc.mu.Unlock() - - var errs []error - - for _, namespace := range mc.namespace { - if err := namespace.conn.Close(); err != nil { - errs = append(errs, err) - } - if namespace.authManager != nil { - if err := namespace.authManager.Close(); err != nil { - errs = append(errs, err) - } - } - } - - return errors.Join(errs...) + return tlsConfig, grpcInterceptor, nil } // Invoke implements the grpc.ClientConnInterface Invoke method -func (mc *Conn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { +func (p *proxyServer) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + start := time.Now() + p.metricsHandler.Counter(metrics.ProxyRequestTotal).Inc(1) + md, ok := metadata.FromIncomingContext(ctx) if !ok { + p.metricsHandler.WithTags(map[string]string{"error": "unable to read metadata"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Errorf(codes.InvalidArgument, "unable to read metadata") } workloadId := md.Get("workload-id") if len(workloadId) <= 0 { + p.metricsHandler.WithTags(map[string]string{"error": "metadata missing workload-id"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Error(codes.InvalidArgument, "metadata missing workload-id") } if len(workloadId) != 1 { + p.metricsHandler.WithTags(map[string]string{"error": "metadata contains multiple workload-id entries"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Error(codes.InvalidArgument, "metadata contains multiple workload-id entries") } - mc.mu.RLock() - namespace, exists := mc.namespace[workloadId[0]] - mc.mu.RUnlock() + p.mu.RLock() + namespace, exists := p.connectionMux[workloadId[0]] + p.mu.RUnlock() if !exists { - return status.Errorf(codes.InvalidArgument, "invalid workload-id: %s", workloadId[0]) + p.logger.Warn("invalid workload-id", zap.String("workload-id", workloadId[0])) + p.metricsHandler.WithTags(map[string]string{"error": "invalid workload-id"}).Counter(metrics.ProxyRequestErrors).Inc(1) + return status.Errorf(codes.InvalidArgument, "invalid workload-id") } - if namespace.authManager != nil { + if namespace.GetAuthenticator() != nil { authorization := md.Get("authorization") if len(authorization) < 1 { + namespace.metricsHandler.WithTags(map[string]string{"error": "metadata is missing authorization"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Error(codes.InvalidArgument, "metadata is missing authorization") } else if len(authorization) > 1 { + namespace.metricsHandler.WithTags(map[string]string{"error": "metadata contains multiple authorization entries"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Error(codes.InvalidArgument, "metadata contains multiple authorization entries") } - result, err := namespace.authManager.Authenticate(ctx, namespace.authType, authorization[0]) + result, err := namespace.GetAuthenticator().Authenticate(ctx, authorization[0]) if err != nil { - return status.Errorf(codes.Unknown, "failed to authenticate: %s", err) + namespace.metricsHandler.WithTags(map[string]string{"error": "failed to authenticate"}).Counter(metrics.ProxyRequestErrors).Inc(1) + return status.Errorf(codes.Unknown, "failed to authenticate: %v", err) } if !result.Authenticated { + namespace.metricsHandler.WithTags(map[string]string{"error": "invalid token"}).Counter(metrics.ProxyRequestErrors).Inc(1) return status.Errorf(codes.Unauthenticated, "invalid token") } } - return namespace.conn.Invoke(ctx, method, args, reply, opts...) + p.logger.Debug("invoking method", + zap.String("workload-id", workloadId[0]), + zap.String("method", method), + zap.Any("args", args), + zap.Any("md", md), + ) + + namespace.metricsHandler.Counter(metrics.ProxyRequestSuccess).Inc(1) + namespace.metricsHandler.Timer(metrics.ProxyLatency).Record(time.Since(start)) + return namespace.GetConnection().Invoke(ctx, method, args, reply, opts...) } // NewStream implements the grpc.ClientConnInterface NewStream method -func (mc *Conn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { +func (p *proxyServer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { return nil, status.Error(codes.Unimplemented, "streams not supported") } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 1d22f7c..1d17703 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -2,657 +2,528 @@ package proxy import ( "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "errors" - "fmt" - "math/big" - "net" - "os" - "sync" "testing" - "time" - - "temporal-sa/temporal-cloud-proxy/auth" - "temporal-sa/temporal-cloud-proxy/metrics" - "temporal-sa/temporal-cloud-proxy/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "google.golang.org/grpc" + "go.temporal.io/api/common/v1" + "go.temporal.io/sdk/converter" + "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + + "github.com/temporal-sa/temporal-cloud-proxy/auth" + "github.com/temporal-sa/temporal-cloud-proxy/codec" + "github.com/temporal-sa/temporal-cloud-proxy/config" ) -// MockAuthManager is a mock implementation of the AuthManager interface -type MockAuthManager struct { +// Mock implementations +type MockConfigProvider struct { + config config.ProxyConfig +} + +func (m *MockConfigProvider) GetProxyConfig() config.ProxyConfig { + return m.config +} + +type MockAuthenticatorFactory struct { mock.Mock } -func (m *MockAuthManager) Authenticate(ctx context.Context, authType string, credentials string) (*auth.AuthenticationResult, error) { - args := m.Called(ctx, authType, credentials) +func (m *MockAuthenticatorFactory) NewAuthenticator(authConfig config.AuthConfig) (auth.Authenticator, error) { + args := m.Called(authConfig) if args.Get(0) == nil { return nil, args.Error(1) } - return args.Get(0).(*auth.AuthenticationResult), args.Error(1) + return args.Get(0).(auth.Authenticator), args.Error(1) } -func (m *MockAuthManager) Close() error { +type MockAuthenticator struct { + mock.Mock +} + +func (m *MockAuthenticator) Type() string { args := m.Called() - return args.Error(0) + return args.String(0) } -// Helper function to create test TLS certificates -func createTestCertificates(t *testing.T) (string, string) { - // Generate private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) +func (m *MockAuthenticator) Init(ctx context.Context, config map[string]interface{}) error { + args := m.Called(ctx, config) + return args.Error(0) +} - // Create certificate template - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test"}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"Test"}, - StreetAddress: []string{""}, - PostalCode: []string{""}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, +func (m *MockAuthenticator) Authenticate(ctx context.Context, credentials interface{}) (*auth.AuthenticationResult, error) { + args := m.Called(ctx, credentials) + if args.Get(0) == nil { + return nil, args.Error(1) } + return args.Get(0).(*auth.AuthenticationResult), args.Error(1) +} - // Create certificate - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - require.NoError(t, err) - - // Create temporary files - certFile, err := os.CreateTemp("", "test-cert-*.pem") - require.NoError(t, err) - defer certFile.Close() - - keyFile, err := os.CreateTemp("", "test-key-*.pem") - require.NoError(t, err) - defer keyFile.Close() - - // Write certificate - err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - require.NoError(t, err) +func (m *MockAuthenticator) Close() error { + args := m.Called() + return args.Error(0) +} - // Write private key - privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey) - require.NoError(t, err) - err = pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyDER}) - require.NoError(t, err) +type MockEncryptionCodecFactory struct { + mock.Mock +} - return certFile.Name(), keyFile.Name() +func (m *MockEncryptionCodecFactory) NewEncryptionCodec(options codec.EncryptionCodecOptions) (converter.PayloadCodec, error) { + args := m.Called(options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(converter.PayloadCodec), args.Error(1) } -func TestNewConn(t *testing.T) { - conn := NewConn() +type MockPayloadCodec struct { + mock.Mock +} - assert.NotNil(t, conn) - assert.NotNil(t, conn.namespace) - assert.Equal(t, 0, len(conn.namespace)) +func (m *MockPayloadCodec) Encode(payloads []*common.Payload) ([]*common.Payload, error) { + args := m.Called(payloads) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*common.Payload), args.Error(1) } -func TestConn_AddConn(t *testing.T) { - // Create test certificates - certPath, keyPath := createTestCertificates(t) - defer os.Remove(certPath) - defer os.Remove(keyPath) +func (m *MockPayloadCodec) Decode(payloads []*common.Payload) ([]*common.Payload, error) { + args := m.Called(payloads) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*common.Payload), args.Error(1) +} - tests := []struct { - name string - input AddConnInput - expectError bool - errorMsg string - }{ - { - name: "successful connection addition with TLS", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - TLS: &utils.TLSConfig{ - CertFile: certPath, - KeyFile: keyPath, - }, - }, - }, - EncryptionKey: "test-key-id", - }, - AuthManager: nil, // Use nil for simplicity in tests - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, - }, - expectError: false, +// Test helper functions +func createValidConfig() config.ProxyConfig { + return config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 7233, + Host: "0.0.0.0", }, - { - name: "successful connection addition with API key (value)", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id-api", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - ApiKey: &utils.TemporalApiKeyConfig{ - Value: "test-api-key", - }, - }, - }, - EncryptionKey: "test-key-id", - }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, - }, - expectError: false, + Metrics: config.MetricsConfig{ + Port: 9090, }, - { - name: "successful connection addition with API key (env var)", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id-api-env", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - ApiKey: &utils.TemporalApiKeyConfig{ - EnvVar: "TEST_TEMPORAL_API_KEY", - }, + Workloads: []config.WorkloadConfig{ + { + WorkloadId: "test-workload", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{ + Value: "test-api-key", }, }, - EncryptionKey: "test-key-id", }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, }, - expectError: true, // Will fail because env var is not set }, - { - name: "invalid certificate path", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - TLS: &utils.TLSConfig{ - CertFile: "/nonexistent/cert.pem", - KeyFile: keyPath, - }, - }, - }, - EncryptionKey: "test-key-id", - }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, - }, - expectError: true, + } +} + +func createConfigWithMultipleWorkloads() config.ProxyConfig { + return config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 7233, + Host: "0.0.0.0", }, - { - name: "invalid key path", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - TLS: &utils.TLSConfig{ - CertFile: certPath, - KeyFile: "/nonexistent/key.pem", - }, + Metrics: config.MetricsConfig{ + Port: 9090, + }, + Workloads: []config.WorkloadConfig{ + { + WorkloadId: "workload-1", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test1.namespace", + HostPort: "test1.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{ + Value: "api-key-1", }, }, - EncryptionKey: "test-key-id", }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, }, - expectError: true, - }, - { - name: "both API key and TLS configured - should error", - input: AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: "test-workload-id", - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - ApiKey: &utils.TemporalApiKeyConfig{ - Value: "test-api-key", - }, - TLS: &utils.TLSConfig{ - CertFile: certPath, - KeyFile: keyPath, - }, + { + WorkloadId: "workload-2", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test2.namespace", + HostPort: "test2.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{ + Value: "api-key-2", }, }, - EncryptionKey: "test-key-id", }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, }, - expectError: true, - errorMsg: "cannot have both api key and mtls authentication", }, } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - conn := NewConn() - err := conn.AddConn(tt.input) - - if tt.expectError { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - assert.Equal(t, 1, len(conn.namespace)) - - // Verify the connection was stored correctly - nsConn, exists := conn.namespace[tt.input.Workload.WorkloadId] - assert.True(t, exists) - assert.NotNil(t, nsConn.conn) - assert.Equal(t, tt.input.AuthManager, nsConn.authManager) - assert.Equal(t, tt.input.AuthType, nsConn.authType) - } - }) - } -} - -func TestConn_CloseAll_Empty(t *testing.T) { - conn := NewConn() - err := conn.CloseAll() - assert.NoError(t, err) } -func TestConn_Invoke(t *testing.T) { +func TestNewProxyProvider_Success(t *testing.T) { tests := []struct { - name string - setupContext func() context.Context - setupConn func() *Conn - method string - args interface{} - reply interface{} - expectError bool - expectedCode codes.Code - errorContains string + name string + config config.ProxyConfig }{ { - name: "missing metadata", - setupContext: func() context.Context { - return context.Background() - }, - setupConn: func() *Conn { - return NewConn() - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - }, - { - name: "missing workload-id", - setupContext: func() context.Context { - md := metadata.New(map[string]string{}) - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - return NewConn() - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "metadata missing workload-id", - }, - { - name: "multiple workload-id entries", - setupContext: func() context.Context { - md := metadata.New(map[string]string{}) - md.Append("workload-id", "workload-id-1") - md.Append("workload-id", "workload-id-2") - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - return NewConn() - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "multiple workload-id entries", - }, - { - name: "workload not found", - setupContext: func() context.Context { - md := metadata.New(map[string]string{ - "workload-id": "nonexistent-workload-id", - }) - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - return NewConn() - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "invalid workload-id: nonexistent-workload-id", - }, - { - name: "invoke without authentication - skips auth logic", - setupContext: func() context.Context { - md := metadata.New(map[string]string{ - "workload-id": "test-workload-id-no-auth", - }) - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - conn := NewConn() - // Don't add any namespace connections to test the "workload not found" path - // This way we can test the logic without hitting the nil pointer - return conn - }, - method: "/test.Service/Method", - args: struct{}{}, - reply: struct{}{}, - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "invalid workload-id: test-workload-id-no-auth", + name: "single workload with API key", + config: createValidConfig(), }, { - name: "missing authorization with auth manager", - setupContext: func() context.Context { - md := metadata.New(map[string]string{ - "workload-id": "test-workload-id", - }) - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - conn := NewConn() - // Create a real auth manager for testing - authManager := auth.NewAuthManager() - - conn.namespace["test-workload-id"] = NamespaceConn{ - conn: nil, - authManager: authManager, - authType: "jwt", - } - return conn - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "metadata is missing authorization", - }, - { - name: "multiple authorization entries", - setupContext: func() context.Context { - md := metadata.New(map[string]string{ - "workload-id": "test-workload-id", - }) - md.Append("authorization", "Bearer token1") - md.Append("authorization", "Bearer token2") - return metadata.NewIncomingContext(context.Background(), md) - }, - setupConn: func() *Conn { - conn := NewConn() - // Create a real auth manager for testing - authManager := auth.NewAuthManager() - - conn.namespace["test-workload-id"] = NamespaceConn{ - conn: nil, - authManager: authManager, - authType: "jwt", - } - return conn - }, - method: "/test.Service/Method", - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "multiple authorization entries", + name: "multiple workloads", + config: createConfigWithMultipleWorkloads(), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := tt.setupContext() - conn := tt.setupConn() + configProvider := &MockConfigProvider{config: tt.config} + logger := zap.NewNop() - err := conn.Invoke(ctx, tt.method, tt.args, tt.reply) + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} - if tt.expectError { - assert.Error(t, err) - if tt.expectedCode != codes.OK { - st, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, tt.expectedCode, st.Code()) + // Setup mocks - no authentication or encryption for basic test + for _, workload := range tt.config.Workloads { + if workload.Authentication != nil { + mockAuth := &MockAuthenticator{} + mockAuthFactory.On("NewAuthenticator", *workload.Authentication).Return(mockAuth, nil) } - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) + if workload.Encryption != nil { + mockCodec := &MockPayloadCodec{} + mockCodecFactory.On("NewEncryptionCodec", mock.AnythingOfType("codec.EncryptionCodecOptions")).Return(mockCodec, nil) } - } else { - assert.NoError(t, err) } + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + + assert.NoError(t, err) + assert.NotNil(t, provider) + assert.NotNil(t, provider.GetConnectionMux()) + + // Test that we can start and stop the provider + err = provider.Start() + assert.NoError(t, err) + + err = provider.Stop() + assert.NoError(t, err) + + mockAuthFactory.AssertExpectations(t) + mockCodecFactory.AssertExpectations(t) }) } } -func TestConn_NewStream(t *testing.T) { - conn := NewConn() - ctx := context.Background() - desc := &grpc.StreamDesc{} - method := "/test.Service/StreamMethod" +func TestNewProxyProvider_AuthenticatorFactoryError(t *testing.T) { + configProvider := &MockConfigProvider{config: config.ProxyConfig{ + Server: config.ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: config.MetricsConfig{Port: 9090}, + Workloads: []config.WorkloadConfig{ + { + WorkloadId: "test-workload", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{Value: "test-key"}, + }, + }, + Authentication: &config.AuthConfig{ + Type: "jwt", + Config: map[string]interface{}{"jwks-url": "http://example.com"}, + }, + }, + }, + }} - stream, err := conn.NewStream(ctx, desc, method) + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} - assert.Nil(t, stream) - assert.Error(t, err) + // Setup mock to return error + expectedErr := errors.New("failed to create authenticator") + mockAuthFactory.On("NewAuthenticator", mock.AnythingOfType("config.AuthConfig")).Return(nil, expectedErr) - st, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, codes.Unimplemented, st.Code()) - assert.Contains(t, err.Error(), "streams not supported") -} + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) -func TestCreateKMSClient(t *testing.T) { - // Test with environment variable - originalRegion := os.Getenv("AWS_REGION") - defer func() { - if originalRegion != "" { - os.Setenv("AWS_REGION", originalRegion) - } else { - os.Unsetenv("AWS_REGION") - } - }() - - // Test with custom region - os.Setenv("AWS_REGION", "us-east-1") - client := createKMSClient() - assert.NotNil(t, client) - - // Test with default region - os.Unsetenv("AWS_REGION") - client = createKMSClient() - assert.NotNil(t, client) + assert.Error(t, err) + assert.Nil(t, provider) + assert.Contains(t, err.Error(), "failed to create authenticator") + + mockAuthFactory.AssertExpectations(t) } -func TestConn_ConcurrentAccess(t *testing.T) { - // Create test certificates - certPath, keyPath := createTestCertificates(t) - defer os.Remove(certPath) - defer os.Remove(keyPath) - - conn := NewConn() - - // Test concurrent AddConn operations - numConnections := 10 - var wg sync.WaitGroup - wg.Add(numConnections) - - for i := 0; i < numConnections; i++ { - go func(id int) { - defer wg.Done() - - input := AddConnInput{ - Workload: &utils.WorkloadConfig{ - WorkloadId: fmt.Sprintf("workload-id-%d", id), - TemporalCloud: utils.TemporalCloudConfig{ - Namespace: fmt.Sprintf("namespace-%d", id), - HostPort: "localhost:7233", - Authentication: utils.TemporalAuthConfig{ - TLS: &utils.TLSConfig{ - CertFile: certPath, - KeyFile: keyPath, - }, - }, +func TestNewProxyProvider_EncryptionCodecFactoryError(t *testing.T) { + configProvider := &MockConfigProvider{config: config.ProxyConfig{ + Server: config.ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: config.MetricsConfig{Port: 9090}, + Workloads: []config.WorkloadConfig{ + { + WorkloadId: "test-workload", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{Value: "test-key"}, }, - EncryptionKey: "test-key-id", }, - AuthManager: nil, - AuthType: "jwt", - MetricsHandler: metrics.NewMetricsHandler(metrics.MetricsHandlerOptions{}), - CryptoCachingConfig: nil, - } + Encryption: &config.EncryptionConfig{ + Type: "aws-kms", + Config: map[string]interface{}{"key-id": "test-key"}, + }, + }, + }, + }} - err := conn.AddConn(input) - assert.NoError(t, err) - }(i) - } + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} - wg.Wait() + // Setup mock to return error + expectedErr := errors.New("failed to create encryption codec") + mockCodecFactory.On("NewEncryptionCodec", mock.AnythingOfType("codec.EncryptionCodecOptions")).Return(nil, expectedErr) - // Verify all connections were added - assert.Equal(t, numConnections, len(conn.namespace)) + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) - // Test concurrent Invoke operations - numInvokes := 50 - wg.Add(numInvokes) + assert.Error(t, err) + assert.Nil(t, provider) + assert.Contains(t, err.Error(), "failed to create encryption codec") - for i := 0; i < numInvokes; i++ { - go func(id int) { - defer wg.Done() + mockCodecFactory.AssertExpectations(t) +} - workloadId := id % numConnections - md := metadata.New(map[string]string{ - "workload-id": fmt.Sprintf("workload-id-%d", workloadId), - }) - ctx := metadata.NewIncomingContext(context.Background(), md) +func TestProxyServer_Invoke_Success(t *testing.T) { + configProvider := &MockConfigProvider{config: createValidConfig()} + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} - // This will fail because we don't have real gRPC connections, - // but it tests the concurrent access to the namespace map - conn.Invoke(ctx, "/test.Service/Method", struct{}{}, struct{}{}) - }(i) + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) + + // Create context with metadata + md := metadata.New(map[string]string{ + "workload-id": "test-workload", + }) + ctx := metadata.NewIncomingContext(context.Background(), md) + + // Test invoke - this will fail because we don't have a real gRPC connection, + // but we can test the metadata validation logic + err = provider.GetConnectionMux().Invoke(ctx, "/test.Service/TestMethod", nil, nil) + + // We expect this to fail with a connection error, not a validation error + assert.Error(t, err) + // Should not be a validation error (InvalidArgument) + st, ok := status.FromError(err) + if ok { + assert.NotEqual(t, codes.InvalidArgument, st.Code()) } +} - wg.Wait() +func TestProxyServer_Invoke_MissingWorkloadId(t *testing.T) { + configProvider := &MockConfigProvider{config: createValidConfig()} + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) + + // Create context without workload-id metadata + ctx := context.Background() + + err = provider.GetConnectionMux().Invoke(ctx, "/test.Service/TestMethod", nil, nil) + + assert.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "unable to read metadata") } -// Test authentication logic with a mock that can be properly cast -func TestConn_InvokeWithAuthentication(t *testing.T) { - conn := NewConn() +func TestProxyServer_Invoke_InvalidWorkloadId(t *testing.T) { + configProvider := &MockConfigProvider{config: createValidConfig()} + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) - // Create a mock auth manager - mockAuth := &MockAuthManager{} - mockAuth.On("Authenticate", mock.Anything, "jwt", "Bearer valid-token").Return( - &auth.AuthenticationResult{ - Authenticated: true, - Subject: "test-user", - }, nil) + // Create context with invalid workload-id + md := metadata.New(map[string]string{ + "workload-id": "invalid-workload", + }) + ctx := metadata.NewIncomingContext(context.Background(), md) - mockAuth.On("Authenticate", mock.Anything, "jwt", "Bearer invalid-token").Return( - nil, errors.New("invalid token")) + err = provider.GetConnectionMux().Invoke(ctx, "/test.Service/TestMethod", nil, nil) - mockAuth.On("Authenticate", mock.Anything, "jwt", "Bearer expired-token").Return( - &auth.AuthenticationResult{ - Authenticated: false, - }, nil) + assert.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "invalid workload-id") +} - // We can't easily cast our mock to *auth.AuthManager due to Go's type system, - // so we'll test the authentication logic indirectly by testing the error cases - // that don't require the actual authentication call. +func TestProxyServer_Invoke_MultipleWorkloadIds(t *testing.T) { + configProvider := &MockConfigProvider{config: createValidConfig()} + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) + + // Create context with multiple workload-id values + md := metadata.New(map[string]string{}) + md.Append("workload-id", "workload-1") + md.Append("workload-id", "workload-2") + ctx := metadata.NewIncomingContext(context.Background(), md) + + err = provider.GetConnectionMux().Invoke(ctx, "/test.Service/TestMethod", nil, nil) + + assert.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "multiple workload-id entries") +} + +func TestProxyServer_Invoke_WithAuthentication(t *testing.T) { + configProvider := &MockConfigProvider{config: config.ProxyConfig{ + Server: config.ServerConfig{Port: 7233, Host: "0.0.0.0"}, + Metrics: config.MetricsConfig{Port: 9090}, + Workloads: []config.WorkloadConfig{ + { + WorkloadId: "test-workload", + TemporalCloud: config.TemporalCloudConfig{ + Namespace: "test.namespace", + HostPort: "test.namespace.tmprl.cloud:7233", + Authentication: config.TemporalAuthConfig{ + ApiKey: &config.TemporalApiKeyConfig{Value: "test-key"}, + }, + }, + Authentication: &config.AuthConfig{ + Type: "jwt", + Config: map[string]interface{}{"jwks-url": "http://example.com"}, + }, + }, + }, + }} + + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} + + mockAuth := &MockAuthenticator{} + mockAuthFactory.On("NewAuthenticator", mock.AnythingOfType("config.AuthConfig")).Return(mockAuth, nil) + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) tests := []struct { - name string - setupContext func() context.Context - expectError bool - expectedCode codes.Code - errorContains string + name string + setupAuth func() + metadata map[string]string + expectedCode codes.Code + expectedMsg string }{ { name: "missing authorization header", - setupContext: func() context.Context { - md := metadata.New(map[string]string{ - "workload-id": "test-workload-id", - }) - return metadata.NewIncomingContext(context.Background(), md) + setupAuth: func() { + // No setup needed + }, + metadata: map[string]string{ + "workload-id": "test-workload", }, - expectError: true, - expectedCode: codes.InvalidArgument, - errorContains: "metadata is missing authorization", + expectedCode: codes.InvalidArgument, + expectedMsg: "metadata is missing authorization", + }, + { + name: "authentication failure", + setupAuth: func() { + mockAuth.On("Authenticate", mock.Anything, "Bearer invalid-token").Return( + &auth.AuthenticationResult{Authenticated: false}, nil) + }, + metadata: map[string]string{ + "workload-id": "test-workload", + "authorization": "Bearer invalid-token", + }, + expectedCode: codes.Unauthenticated, + expectedMsg: "invalid token", + }, + { + name: "authentication error", + setupAuth: func() { + mockAuth.On("Authenticate", mock.Anything, "Bearer error-token").Return( + nil, errors.New("auth service error")) + }, + metadata: map[string]string{ + "workload-id": "test-workload", + "authorization": "Bearer error-token", + }, + expectedCode: codes.Unknown, + expectedMsg: "failed to authenticate", }, } - // Add a namespace with auth manager (using nil since we can't easily mock the interface) - conn.namespace["test-workload-id"] = NamespaceConn{ - conn: nil, // Will cause failure, but we're testing auth logic first - authManager: nil, // We'll set this to non-nil to trigger auth checks - authType: "jwt", - } - - // Set authManager to non-nil to trigger the auth logic - // Use a real auth manager since we can't easily mock the interface - authManager := auth.NewAuthManager() - nsConn := conn.namespace["test-workload-id"] - nsConn.authManager = authManager - conn.namespace["test-workload-id"] = nsConn - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := tt.setupContext() + tt.setupAuth() - err := conn.Invoke(ctx, "/test.Service/Method", struct{}{}, struct{}{}) + md := metadata.New(tt.metadata) + ctx := metadata.NewIncomingContext(context.Background(), md) - if tt.expectError { - assert.Error(t, err) - if tt.expectedCode != codes.OK { - st, ok := status.FromError(err) - assert.True(t, ok) - assert.Equal(t, tt.expectedCode, st.Code()) - } - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } + err = provider.GetConnectionMux().Invoke(ctx, "/test.Service/TestMethod", nil, nil) + + assert.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, tt.expectedCode, st.Code()) + assert.Contains(t, st.Message(), tt.expectedMsg) }) } + + mockAuthFactory.AssertExpectations(t) + mockAuth.AssertExpectations(t) +} + +func TestProxyServer_NewStream_NotSupported(t *testing.T) { + configProvider := &MockConfigProvider{config: createValidConfig()} + logger := zap.NewNop() + mockAuthFactory := &MockAuthenticatorFactory{} + mockCodecFactory := &MockEncryptionCodecFactory{} + + provider, err := newProxyProvider(configProvider, logger, mockAuthFactory, mockCodecFactory) + require.NoError(t, err) + + stream, err := provider.GetConnectionMux().NewStream(context.Background(), nil, "/test.Service/TestMethod") + + assert.Nil(t, stream) + assert.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unimplemented, st.Code()) + assert.Contains(t, st.Message(), "streams not supported") +} + +// Note: Testing the Close() method is complex because namespaceConnection.conn is *grpc.ClientConn (concrete type) +// and we can't easily mock it. The error aggregation logic is tested through integration tests. +// The key behavior (error aggregation) is covered by the config validation tests and the overall proxy tests. + +func TestProxyServer_Stop_ErrorAggregation(t *testing.T) { + // This test verifies that Stop() properly aggregates errors from multiple connections + // We can't easily test namespaceConnection.Close() directly due to the concrete grpc.ClientConn type, + // but we can test the error aggregation behavior at the proxy level through integration testing. + + // For now, we'll focus on the more testable aspects of the proxy functionality + // The error aggregation logic in Close() methods is straightforward and follows the same pattern + // as the config validation error aggregation which is thoroughly tested. + + t.Skip("Skipping detailed Close() testing due to concrete grpc.ClientConn type - behavior is covered by integration tests") } diff --git a/transport/fx.go b/transport/fx.go new file mode 100644 index 0000000..e911252 --- /dev/null +++ b/transport/fx.go @@ -0,0 +1,9 @@ +package transport + +import ( + "go.uber.org/fx" +) + +var Module = fx.Provide( + newTransportProvider, +) diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..8a90aa5 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,87 @@ +package transport + +import ( + "context" + "fmt" + "github.com/temporal-sa/temporal-cloud-proxy/config" + "github.com/temporal-sa/temporal-cloud-proxy/proxy" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/sdk/client" + "go.uber.org/fx" + "go.uber.org/zap" + "google.golang.org/grpc" + "net" +) + +type ( + TransportProvider interface { + Start() error + Stop() error + } + + grpcTransportProvider struct { + host string + port int + grpcServer *grpc.Server + logger *zap.Logger + } +) + +func newTransportProvider(lc fx.Lifecycle, configProvider config.ConfigProvider, logger *zap.Logger, + proxyProvider proxy.ProxyProvider) (TransportProvider, error) { + + transportManager := &grpcTransportProvider{ + host: configProvider.GetProxyConfig().Server.Host, + port: configProvider.GetProxyConfig().Server.Port, + logger: logger, + } + + workflowClient := workflowservice.NewWorkflowServiceClient(proxyProvider.GetConnectionMux()) + + handler, err := client.NewWorkflowServiceProxyServer( + client.WorkflowServiceProxyOptions{Client: workflowClient}, + ) + if err != nil { + return nil, err + } + + transportManager.grpcServer = grpc.NewServer() + workflowservice.RegisterWorkflowServiceServer(transportManager.grpcServer, handler) + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return transportManager.Start() + }, + OnStop: func(ctx context.Context) error { + return transportManager.Stop() + }, + }) + + return transportManager, nil +} + +func (t *grpcTransportProvider) Start() error { + lis, err := net.Listen("tcp", t.getHostPort()) + if err != nil { + return err + } + + t.logger.Info( + "proxy started", + zap.String("host", t.host), + zap.Int("port", t.port), + ) + + go t.grpcServer.Serve(lis) + + return nil +} + +func (t *grpcTransportProvider) Stop() error { + t.grpcServer.GracefulStop() + return nil +} + +func (t *grpcTransportProvider) getHostPort() string { + return fmt.Sprintf("%s:%d", t.host, t.port) +} diff --git a/transport/transport_test.go b/transport/transport_test.go new file mode 100644 index 0000000..98e67d2 --- /dev/null +++ b/transport/transport_test.go @@ -0,0 +1,203 @@ +package transport + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/fx" + "go.uber.org/zap" + "google.golang.org/grpc" + + "github.com/temporal-sa/temporal-cloud-proxy/config" +) + +// Mock implementations +type MockConfigProvider struct { + config config.ProxyConfig +} + +func (m *MockConfigProvider) GetProxyConfig() config.ProxyConfig { + return m.config +} + +type MockProxyProvider struct { + mock.Mock +} + +func (m *MockProxyProvider) GetConnectionMux() grpc.ClientConnInterface { + args := m.Called() + return args.Get(0).(grpc.ClientConnInterface) +} + +func (m *MockProxyProvider) Start() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockProxyProvider) Stop() error { + args := m.Called() + return args.Error(0) +} + +type MockClientConn struct { + mock.Mock +} + +func (m *MockClientConn) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error { + mockArgs := m.Called(ctx, method, args, reply, opts) + return mockArgs.Error(0) +} + +func (m *MockClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + mockArgs := m.Called(ctx, desc, method, opts) + return nil, mockArgs.Error(1) +} + +func TestNewTransportProvider_Success(t *testing.T) { + configProvider := &MockConfigProvider{ + config: config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 7233, + Host: "0.0.0.0", + }, + Metrics: config.MetricsConfig{ + Port: 9090, + }, + }, + } + + logger := zap.NewNop() + mockProxyProvider := &MockProxyProvider{} + mockClientConn := &MockClientConn{} + + // Setup mock expectations + mockProxyProvider.On("GetConnectionMux").Return(mockClientConn) + + // Create a mock lifecycle that doesn't actually do anything + mockLifecycle := &MockLifecycle{} + + // Test the transport provider creation + transportProvider, err := newTransportProvider(mockLifecycle, configProvider, logger, mockProxyProvider) + + assert.NoError(t, err) + assert.NotNil(t, transportProvider) + + // Verify mock expectations + mockProxyProvider.AssertExpectations(t) +} + +func TestTransportProvider_StartStop(t *testing.T) { + configProvider := &MockConfigProvider{ + config: config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 0, // Use port 0 to let the OS assign a free port + Host: "127.0.0.1", + }, + }, + } + + logger := zap.NewNop() + mockProxyProvider := &MockProxyProvider{} + mockClientConn := &MockClientConn{} + + // Setup mock expectations + mockProxyProvider.On("GetConnectionMux").Return(mockClientConn) + + // Create a mock lifecycle + mockLifecycle := &MockLifecycle{} + + transportProvider, err := newTransportProvider(mockLifecycle, configProvider, logger, mockProxyProvider) + require.NoError(t, err) + + // Test Start + err = transportProvider.Start() + assert.NoError(t, err) + + // Test Stop + err = transportProvider.Stop() + assert.NoError(t, err) + + // Verify mock expectations + mockProxyProvider.AssertExpectations(t) +} + +func TestTransportProvider_Configuration(t *testing.T) { + tests := []struct { + name string + config config.ProxyConfig + }{ + { + name: "default configuration", + config: config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 7233, + Host: "0.0.0.0", + }, + }, + }, + { + name: "custom host and port", + config: config.ProxyConfig{ + Server: config.ServerConfig{ + Port: 8080, + Host: "127.0.0.1", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configProvider := &MockConfigProvider{config: tt.config} + logger := zap.NewNop() + mockProxyProvider := &MockProxyProvider{} + mockClientConn := &MockClientConn{} + + // Setup mock expectations + mockProxyProvider.On("GetConnectionMux").Return(mockClientConn) + + // Create a mock lifecycle + mockLifecycle := &MockLifecycle{} + + transportProvider, err := newTransportProvider(mockLifecycle, configProvider, logger, mockProxyProvider) + + assert.NoError(t, err) + assert.NotNil(t, transportProvider) + + // Cast to concrete type to verify configuration + concreteProvider, ok := transportProvider.(*grpcTransportProvider) + require.True(t, ok) + + assert.Equal(t, tt.config.Server.Host, concreteProvider.host) + assert.Equal(t, tt.config.Server.Port, concreteProvider.port) + + // Verify mock expectations + mockProxyProvider.AssertExpectations(t) + }) + } +} + +// Mock lifecycle for testing +type MockLifecycle struct { + hooks []fx.Hook +} + +func (m *MockLifecycle) Append(hook fx.Hook) { + // In a real test, we might want to capture and execute these hooks + // For now, we'll just ignore them since we're testing the transport provider directly + m.hooks = append(m.hooks, hook) +} + +// Note: This is a simplified test suite that focuses on the basic functionality +// of the transport provider. In a production environment, you might want to add +// more comprehensive tests including: +// - gRPC server integration tests +// - Error handling scenarios +// - Concurrent request handling +// - Graceful shutdown behavior +// +// However, these would require more complex setup and potentially real network +// connections, which goes beyond the scope of unit testing. diff --git a/utils/config.go b/utils/config.go deleted file mode 100644 index c886075..0000000 --- a/utils/config.go +++ /dev/null @@ -1,60 +0,0 @@ -package utils - -type Config struct { - Server ServerConfig `yaml:"server"` - Metrics MetricsConfig `yaml:"metrics"` - Encryption EncryptionConfig `yaml:"encryption"` - Workloads []WorkloadConfig `yaml:"workloads"` -} - -type ServerConfig struct { - Port int `yaml:"port"` - Host string `yaml:"host"` -} - -type MetricsConfig struct { - Port int `yaml:"port"` -} - -type EncryptionConfig struct { - Caching CachingConfig `yaml:"caching"` -} - -type CachingConfig struct { - MaxCache int `yaml:"max_cache,omitempty"` - MaxAge string `yaml:"max_age,omitempty"` - MaxUsage int `yaml:"max_usage,omitempty"` -} - -type WorkloadConfig struct { - WorkloadId string `yaml:"workload_id"` - TemporalCloud TemporalCloudConfig `yaml:"temporal_cloud"` - EncryptionKey string `yaml:"encryption_key"` - Authentication *AuthConfig `yaml:"authentication,omitempty"` -} - -type TemporalCloudConfig struct { - Namespace string `yaml:"namespace"` - HostPort string `yaml:"host_port"` - Authentication TemporalAuthConfig `yaml:"authentication"` -} - -type TemporalAuthConfig struct { - TLS *TLSConfig `yaml:"tls,omitempty"` - ApiKey *TemporalApiKeyConfig `yaml:"api_key,omitempty"` -} - -type TemporalApiKeyConfig struct { - Value string `yaml:"value,omitempty"` - EnvVar string `yaml:"env,omitempty"` -} - -type TLSConfig struct { - CertFile string `yaml:"cert_file"` - KeyFile string `yaml:"key_file"` -} - -type AuthConfig struct { - Type string `yaml:"type"` - Config map[string]interface{} `yaml:"config"` -} diff --git a/utils/config_manager.go b/utils/config_manager.go deleted file mode 100644 index 2dd83db..0000000 --- a/utils/config_manager.go +++ /dev/null @@ -1,58 +0,0 @@ -package utils - -import ( - "fmt" - "os" - "sync" - "time" - - "gopkg.in/yaml.v3" -) - -type ConfigManager struct { - configPath string - config *Config - lastLoadTime time.Time - mu sync.RWMutex -} - -func NewConfigManager(configPath string) (*ConfigManager, error) { - cm := &ConfigManager{ - configPath: configPath, - } - - if err := cm.loadConfig(); err != nil { - return nil, err - } - - return cm, nil -} - -func (cm *ConfigManager) GetConfig() *Config { - cm.mu.RLock() - defer cm.mu.RUnlock() - return cm.config -} - -func (cm *ConfigManager) Close() error { - return nil -} - -func (cm *ConfigManager) loadConfig() error { - configFile, err := os.ReadFile(cm.configPath) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - var cfg Config - if err = yaml.Unmarshal(configFile, &cfg); err != nil { - return fmt.Errorf("failed to unmarshal config file: %w", err) - } - - cm.mu.Lock() - cm.config = &cfg - cm.lastLoadTime = time.Now() - cm.mu.Unlock() - - return nil -} diff --git a/utils/config_manager_test.go b/utils/config_manager_test.go deleted file mode 100644 index 7b75730..0000000 --- a/utils/config_manager_test.go +++ /dev/null @@ -1,583 +0,0 @@ -package utils - -import ( - "fmt" - "os" - "path/filepath" - "sync" - "testing" - "time" -) - -func TestNewConfigManager(t *testing.T) { - tests := []struct { - name string - configData string - wantErr bool - expectNil bool - description string - }{ - { - name: "valid config file", - configData: ` -server: - port: 7233 - host: "0.0.0.0" -workloads: - - workload_id: "test.internal" - temporal_cloud: - namespace: "test-namespace" - host_port: "test.external:7233" - authentication: - tls: - cert_file: "/path/to/cert.crt" - key_file: "/path/to/key.key" - encryption_key: "test-key" -`, - wantErr: false, - expectNil: false, - description: "should successfully create config manager with valid config", - }, - { - name: "minimal valid config", - configData: ` -server: - port: 8080 - host: "localhost" -workloads: [] -`, - wantErr: false, - expectNil: false, - description: "should handle minimal config with empty workloads", - }, - { - name: "invalid yaml", - configData: ` -server: - port: 7233 - host: "0.0.0.0" -targets: - - proxy_id: "test.internal" - target: "test.external:7233" - invalid: yaml: [ -`, - wantErr: true, - expectNil: true, - description: "should fail with invalid YAML", - }, - { - name: "empty config file", - configData: "", - wantErr: false, - expectNil: false, - description: "should handle empty config file", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create temporary config file - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(tt.configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - - if (err != nil) != tt.wantErr { - t.Errorf("NewConfigManager() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if (cm == nil) != tt.expectNil { - t.Errorf("NewConfigManager() returned nil = %v, expectNil %v", cm == nil, tt.expectNil) - return - } - - if !tt.wantErr && cm != nil { - // Verify the config manager was properly initialized - if cm.configPath != configPath { - t.Errorf("Expected configPath to be %s, got %s", configPath, cm.configPath) - } - - config := cm.GetConfig() - if config == nil { - t.Error("Expected GetConfig() to return non-nil config") - } - - // Verify lastLoadTime was set - if cm.lastLoadTime.IsZero() { - t.Error("Expected lastLoadTime to be set") - } - } - }) - } -} - -func TestNewConfigManager_FileNotFound(t *testing.T) { - nonExistentPath := "/path/that/does/not/exist/config.yaml" - - cm, err := NewConfigManager(nonExistentPath) - - if err == nil { - t.Error("Expected error when config file does not exist") - } - - if cm != nil { - t.Error("Expected ConfigManager to be nil when file does not exist") - } -} - -func TestConfigManager_GetConfig(t *testing.T) { - configData := ` -server: - port: 9090 - host: "127.0.0.1" -workloads: - - workload_id: "test1.internal" - temporal_cloud: - namespace: "namespace1" - host_port: "test1.external:9090" - authentication: - tls: - cert_file: "/test1.crt" - key_file: "/test1.key" - encryption_key: "key1" - - workload_id: "test2.internal" - temporal_cloud: - namespace: "namespace2" - host_port: "test2.external:9091" - authentication: - tls: - cert_file: "/test2.crt" - key_file: "/test2.key" - encryption_key: "key2" - authentication: - type: "spiffe" - config: - trust_domain: "spiffe://example.org/" -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - t.Fatalf("Failed to create ConfigManager: %v", err) - } - - config := cm.GetConfig() - - if config == nil { - t.Fatal("GetConfig() returned nil") - } - - // Verify server config - if config.Server.Port != 9090 { - t.Errorf("Expected server port to be 9090, got %d", config.Server.Port) - } - if config.Server.Host != "127.0.0.1" { - t.Errorf("Expected server host to be '127.0.0.1', got %s", config.Server.Host) - } - - // Verify workloads - if len(config.Workloads) != 2 { - t.Errorf("Expected 2 workloads, got %d", len(config.Workloads)) - } - - if len(config.Workloads) >= 1 { - workload1 := config.Workloads[0] - if workload1.WorkloadId != "test1.internal" { - t.Errorf("Expected first workload workload_id to be 'test1.internal', got %s", workload1.WorkloadId) - } - if workload1.Authentication != nil { - t.Error("Expected first workload to have no authentication") - } - } - - if len(config.Workloads) >= 2 { - workload2 := config.Workloads[1] - if workload2.WorkloadId != "test2.internal" { - t.Errorf("Expected second workload workload_id to be 'test2.internal', got %s", workload2.WorkloadId) - } - if workload2.Authentication == nil { - t.Error("Expected second workload to have authentication") - } else if workload2.Authentication.Type != "spiffe" { - t.Errorf("Expected second workload auth type to be 'spiffe', got %s", workload2.Authentication.Type) - } - } -} - -func TestConfigManager_GetConfig_ThreadSafety(t *testing.T) { - configData := ` -server: - port: 8080 - host: "localhost" -workloads: - - workload_id: "concurrent.internal" - temporal_cloud: - namespace: "concurrent-namespace" - host_port: "concurrent.external:8080" - authentication: - tls: - cert_file: "/concurrent.crt" - key_file: "/concurrent.key" - encryption_key: "concurrent-key" -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - t.Fatalf("Failed to create ConfigManager: %v", err) - } - - // Test concurrent access to GetConfig - const numGoroutines = 100 - const numIterations = 10 - - var wg sync.WaitGroup - errors := make(chan error, numGoroutines*numIterations) - - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < numIterations; j++ { - config := cm.GetConfig() - if config == nil { - errors <- err - return - } - if config.Server.Port != 8080 { - errors <- err - return - } - if len(config.Workloads) != 1 { - errors <- err - return - } - if config.Workloads[0].WorkloadId != "concurrent.internal" { - errors <- err - return - } - } - }() - } - - wg.Wait() - close(errors) - - for err := range errors { - if err != nil { - t.Errorf("Concurrent access error: %v", err) - } - } -} - -func TestConfigManager_Close(t *testing.T) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: [] -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - t.Fatalf("Failed to create ConfigManager: %v", err) - } - - // Test that Close() doesn't return an error - err = cm.Close() - if err != nil { - t.Errorf("Close() returned error: %v", err) - } - - // Test that we can still get config after Close() (since Close() is currently a no-op) - config := cm.GetConfig() - if config == nil { - t.Error("GetConfig() returned nil after Close()") - } -} - -func TestConfigManager_loadConfig(t *testing.T) { - tests := []struct { - name string - configData string - wantErr bool - description string - }{ - { - name: "valid config", - configData: ` -server: - port: 7233 - host: "0.0.0.0" -targets: - - proxy_id: "test.internal" - target: "test.external:7233" - tls: - cert_file: "/path/to/cert.crt" - key_file: "/path/to/key.key" - encryption_key: "test-key" - namespace: "test-namespace" -`, - wantErr: false, - description: "should load valid config successfully", - }, - { - name: "invalid yaml structure", - configData: ` -server: - port: "invalid_port" # port should be int, not string - host: "0.0.0.0" -targets: [] -`, - wantErr: true, - description: "should fail with invalid YAML structure", - }, - { - name: "malformed yaml", - configData: ` -server: - port: 7233 - host: "0.0.0.0" -targets: - - proxy_id: "test.internal" - target: "test.external:7233" - invalid: [unclosed -`, - wantErr: true, - description: "should fail with malformed YAML", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(tt.configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm := &ConfigManager{ - configPath: configPath, - } - - beforeLoad := time.Now() - err = cm.loadConfig() - afterLoad := time.Now() - - if (err != nil) != tt.wantErr { - t.Errorf("loadConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - // Verify config was loaded - if cm.config == nil { - t.Error("Expected config to be loaded") - } - - // Verify lastLoadTime was updated - if cm.lastLoadTime.Before(beforeLoad) || cm.lastLoadTime.After(afterLoad) { - t.Error("Expected lastLoadTime to be updated during load") - } - } - }) - } -} - -func TestConfigManager_loadConfig_FilePermissions(t *testing.T) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: [] -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - // Remove read permissions - err = os.Chmod(configPath, 0000) - if err != nil { - t.Fatalf("Failed to change file permissions: %v", err) - } - - // Restore permissions after test - defer func() { - os.Chmod(configPath, 0644) - }() - - cm := &ConfigManager{ - configPath: configPath, - } - - err = cm.loadConfig() - if err == nil { - t.Error("Expected error when config file is not readable") - } -} - -func TestConfigManager_ConfigPath(t *testing.T) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: [] -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "test-config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - t.Fatalf("Failed to create ConfigManager: %v", err) - } - - if cm.configPath != configPath { - t.Errorf("Expected configPath to be %s, got %s", configPath, cm.configPath) - } -} - -func TestConfigManager_LastLoadTime(t *testing.T) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: [] -` - - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - beforeCreate := time.Now() - cm, err := NewConfigManager(configPath) - afterCreate := time.Now() - - if err != nil { - t.Fatalf("Failed to create ConfigManager: %v", err) - } - - // Verify lastLoadTime is within expected range - if cm.lastLoadTime.Before(beforeCreate) || cm.lastLoadTime.After(afterCreate) { - t.Errorf("Expected lastLoadTime to be between %v and %v, got %v", - beforeCreate, afterCreate, cm.lastLoadTime) - } -} - -// Benchmark tests -func BenchmarkConfigManager_GetConfig(b *testing.B) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: - - proxy_id: "bench.internal" - target: "bench.external:7233" - tls: - cert_file: "/bench.crt" - key_file: "/bench.key" - encryption_key: "bench-key" - namespace: "bench-namespace" -` - - tmpDir := b.TempDir() - configPath := filepath.Join(tmpDir, "config.yaml") - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - b.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - b.Fatalf("Failed to create ConfigManager: %v", err) - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - config := cm.GetConfig() - if config == nil { - b.Error("GetConfig returned nil") - } - } - }) -} - -func BenchmarkConfigManager_NewConfigManager(b *testing.B) { - configData := ` -server: - port: 7233 - host: "0.0.0.0" -targets: - - proxy_id: "bench.internal" - target: "bench.external:7233" - tls: - cert_file: "/bench.crt" - key_file: "/bench.key" - encryption_key: "bench-key" - namespace: "bench-namespace" -` - - tmpDir := b.TempDir() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - configPath := filepath.Join(tmpDir, fmt.Sprintf("config-%d.yaml", i)) - - err := os.WriteFile(configPath, []byte(configData), 0644) - if err != nil { - b.Fatalf("Failed to create test config file: %v", err) - } - - cm, err := NewConfigManager(configPath) - if err != nil { - b.Fatalf("Failed to create ConfigManager: %v", err) - } - - _ = cm.Close() - } -} diff --git a/utils/config_test.go b/utils/config_test.go deleted file mode 100644 index f955a73..0000000 --- a/utils/config_test.go +++ /dev/null @@ -1,729 +0,0 @@ -package utils - -import ( - "testing" - - "gopkg.in/yaml.v3" -) - -func TestConfig_UnmarshalYAML(t *testing.T) { - tests := []struct { - name string - yamlData string - want Config - wantErr bool - }{ - { - name: "valid complete config with TLS authentication", - yamlData: ` -server: - port: 7233 - host: "0.0.0.0" -metrics: - port: 8080 -encryption: - caching: - max_cache: 100 - max_age: "1h" - max_usage: 1000 -workloads: - - workload_id: "test.namespace.internal" - temporal_cloud: - namespace: "test.namespace" - host_port: "test.namespace.tmprl.cloud:7233" - authentication: - tls: - cert_file: "/path/to/cert.crt" - key_file: "/path/to/key.key" - encryption_key: "test-key" - authentication: - type: "spiffe" - config: - trust_domain: "spiffe://example.org/" - endpoint: "unix:///tmp/spire-agent/public/api.sock" - audiences: - - "test_audience" -`, - want: Config{ - Server: ServerConfig{ - Port: 7233, - Host: "0.0.0.0", - }, - Metrics: MetricsConfig{ - Port: 8080, - }, - Encryption: EncryptionConfig{ - Caching: CachingConfig{ - MaxCache: 100, - MaxAge: "1h", - MaxUsage: 1000, - }, - }, - Workloads: []WorkloadConfig{ - { - WorkloadId: "test.namespace.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "test.namespace", - HostPort: "test.namespace.tmprl.cloud:7233", - Authentication: TemporalAuthConfig{ - TLS: &TLSConfig{ - CertFile: "/path/to/cert.crt", - KeyFile: "/path/to/key.key", - }, - }, - }, - EncryptionKey: "test-key", - Authentication: &AuthConfig{ - Type: "spiffe", - Config: map[string]interface{}{ - "trust_domain": "spiffe://example.org/", - "endpoint": "unix:///tmp/spire-agent/public/api.sock", - "audiences": []interface{}{"test_audience"}, - }, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "valid config with API key authentication (value)", - yamlData: ` -server: - port: 8080 - host: "localhost" -metrics: - port: 9090 -workloads: - - workload_id: "simple.internal" - temporal_cloud: - namespace: "simple" - host_port: "simple.external:8080" - authentication: - api_key: - value: "your-api-key-here" - encryption_key: "simple-key" -`, - want: Config{ - Server: ServerConfig{ - Port: 8080, - Host: "localhost", - }, - Metrics: MetricsConfig{ - Port: 9090, - }, - Encryption: EncryptionConfig{ - Caching: CachingConfig{}, - }, - Workloads: []WorkloadConfig{ - { - WorkloadId: "simple.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "simple", - HostPort: "simple.external:8080", - Authentication: TemporalAuthConfig{ - ApiKey: &TemporalApiKeyConfig{ - Value: "your-api-key-here", - }, - }, - }, - EncryptionKey: "simple-key", - Authentication: nil, - }, - }, - }, - wantErr: false, - }, - { - name: "valid config with API key authentication (env var)", - yamlData: ` -server: - port: 8080 - host: "localhost" -metrics: - port: 9090 -workloads: - - workload_id: "simple.internal" - temporal_cloud: - namespace: "simple" - host_port: "simple.external:8080" - authentication: - api_key: - env: "TEMPORAL_API_KEY" - encryption_key: "simple-key" -`, - want: Config{ - Server: ServerConfig{ - Port: 8080, - Host: "localhost", - }, - Metrics: MetricsConfig{ - Port: 9090, - }, - Encryption: EncryptionConfig{ - Caching: CachingConfig{}, - }, - Workloads: []WorkloadConfig{ - { - WorkloadId: "simple.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "simple", - HostPort: "simple.external:8080", - Authentication: TemporalAuthConfig{ - ApiKey: &TemporalApiKeyConfig{ - EnvVar: "TEMPORAL_API_KEY", - }, - }, - }, - EncryptionKey: "simple-key", - Authentication: nil, - }, - }, - }, - wantErr: false, - }, - { - name: "multiple workloads with mixed authentication", - yamlData: ` -server: - port: 9090 - host: "127.0.0.1" -metrics: - port: 8081 -encryption: - caching: - max_cache: 50 -workloads: - - workload_id: "workload1.internal" - temporal_cloud: - namespace: "namespace1" - host_port: "workload1.external:9090" - authentication: - tls: - cert_file: "/workload1.crt" - key_file: "/workload1.key" - encryption_key: "key1" - - workload_id: "workload2.internal" - temporal_cloud: - namespace: "namespace2" - host_port: "workload2.external:9091" - authentication: - api_key: - value: "workload2-api-key" - encryption_key: "key2" - authentication: - type: "oauth" - config: - client_id: "test-client" - client_secret: "test-secret" -`, - want: Config{ - Server: ServerConfig{ - Port: 9090, - Host: "127.0.0.1", - }, - Metrics: MetricsConfig{ - Port: 8081, - }, - Encryption: EncryptionConfig{ - Caching: CachingConfig{ - MaxCache: 50, - }, - }, - Workloads: []WorkloadConfig{ - { - WorkloadId: "workload1.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "namespace1", - HostPort: "workload1.external:9090", - Authentication: TemporalAuthConfig{ - TLS: &TLSConfig{ - CertFile: "/workload1.crt", - KeyFile: "/workload1.key", - }, - }, - }, - EncryptionKey: "key1", - Authentication: nil, - }, - { - WorkloadId: "workload2.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "namespace2", - HostPort: "workload2.external:9091", - Authentication: TemporalAuthConfig{ - ApiKey: &TemporalApiKeyConfig{ - Value: "workload2-api-key", - }, - }, - }, - EncryptionKey: "key2", - Authentication: &AuthConfig{ - Type: "oauth", - Config: map[string]interface{}{ - "client_id": "test-client", - "client_secret": "test-secret", - }, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "invalid yaml", - yamlData: `invalid: yaml: content: [`, - want: Config{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got Config - err := yaml.Unmarshal([]byte(tt.yamlData), &got) - - if (err != nil) != tt.wantErr { - t.Errorf("yaml.Unmarshal() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - if !configEqual(got, tt.want) { - t.Errorf("yaml.Unmarshal() got = %+v, want %+v", got, tt.want) - } - } - }) - } -} - -func TestServerConfig_Validation(t *testing.T) { - tests := []struct { - name string - config ServerConfig - valid bool - }{ - { - name: "valid server config", - config: ServerConfig{ - Port: 7233, - Host: "0.0.0.0", - }, - valid: true, - }, - { - name: "valid localhost config", - config: ServerConfig{ - Port: 8080, - Host: "localhost", - }, - valid: true, - }, - { - name: "zero port should be handled by application logic", - config: ServerConfig{ - Port: 0, - Host: "localhost", - }, - valid: true, // Structure is valid, business logic should handle port validation - }, - { - name: "empty host should be handled by application logic", - config: ServerConfig{ - Port: 8080, - Host: "", - }, - valid: true, // Structure is valid, business logic should handle host validation - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Since there's no validation method in the struct, we just test that the struct can be created - // In a real application, you might have validation methods - if tt.config.Port < 0 || tt.config.Port > 65535 { - t.Errorf("Port %d is outside valid range", tt.config.Port) - } - }) - } -} - -func TestWorkloadConfig_Structure(t *testing.T) { - workload := WorkloadConfig{ - WorkloadId: "test.internal", - TemporalCloud: TemporalCloudConfig{ - Namespace: "test-namespace", - HostPort: "test.external:7233", - Authentication: TemporalAuthConfig{ - TLS: &TLSConfig{ - CertFile: "/path/to/cert.crt", - KeyFile: "/path/to/key.key", - }, - }, - }, - EncryptionKey: "test-key", - Authentication: &AuthConfig{ - Type: "spiffe", - Config: map[string]interface{}{ - "trust_domain": "spiffe://example.org/", - }, - }, - } - - if workload.WorkloadId != "test.internal" { - t.Errorf("Expected WorkloadId to be 'test.internal', got %s", workload.WorkloadId) - } - if workload.TemporalCloud.HostPort != "test.external:7233" { - t.Errorf("Expected TemporalCloud.HostPort to be 'test.external:7233', got %s", workload.TemporalCloud.HostPort) - } - if workload.EncryptionKey != "test-key" { - t.Errorf("Expected EncryptionKey to be 'test-key', got %s", workload.EncryptionKey) - } - if workload.TemporalCloud.Namespace != "test-namespace" { - t.Errorf("Expected TemporalCloud.Namespace to be 'test-namespace', got %s", workload.TemporalCloud.Namespace) - } - if workload.TemporalCloud.Authentication.TLS.CertFile != "/path/to/cert.crt" { - t.Errorf("Expected TemporalCloud.Authentication.TLS.CertFile to be '/path/to/cert.crt', got %s", workload.TemporalCloud.Authentication.TLS.CertFile) - } - if workload.TemporalCloud.Authentication.TLS.KeyFile != "/path/to/key.key" { - t.Errorf("Expected TemporalCloud.Authentication.TLS.KeyFile to be '/path/to/key.key', got %s", workload.TemporalCloud.Authentication.TLS.KeyFile) - } - if workload.Authentication == nil { - t.Error("Expected Authentication to not be nil") - } else { - if workload.Authentication.Type != "spiffe" { - t.Errorf("Expected Authentication.Type to be 'spiffe', got %s", workload.Authentication.Type) - } - if trustDomain, ok := workload.Authentication.Config["trust_domain"]; !ok || trustDomain != "spiffe://example.org/" { - t.Errorf("Expected trust_domain to be 'spiffe://example.org/', got %v", trustDomain) - } - } -} - -func TestAuthConfig_Types(t *testing.T) { - tests := []struct { - name string - authConfig AuthConfig - wantType string - }{ - { - name: "spiffe auth", - authConfig: AuthConfig{ - Type: "spiffe", - Config: map[string]interface{}{ - "trust_domain": "spiffe://example.org/", - "endpoint": "unix:///tmp/spire-agent/public/api.sock", - }, - }, - wantType: "spiffe", - }, - { - name: "oauth auth", - authConfig: AuthConfig{ - Type: "oauth", - Config: map[string]interface{}{ - "client_id": "test-client", - "client_secret": "test-secret", - }, - }, - wantType: "oauth", - }, - { - name: "custom auth", - authConfig: AuthConfig{ - Type: "custom", - Config: map[string]interface{}{ - "custom_field": "custom_value", - }, - }, - wantType: "custom", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.authConfig.Type != tt.wantType { - t.Errorf("Expected Type to be %s, got %s", tt.wantType, tt.authConfig.Type) - } - if tt.authConfig.Config == nil { - t.Error("Expected Config to not be nil") - } - }) - } -} - -func TestMetricsConfig_Structure(t *testing.T) { - tests := []struct { - name string - config MetricsConfig - want int - }{ - { - name: "default metrics port", - config: MetricsConfig{Port: 8080}, - want: 8080, - }, - { - name: "custom metrics port", - config: MetricsConfig{Port: 9090}, - want: 9090, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.config.Port != tt.want { - t.Errorf("Expected Port to be %d, got %d", tt.want, tt.config.Port) - } - }) - } -} - -func TestEncryptionConfig_Structure(t *testing.T) { - tests := []struct { - name string - config EncryptionConfig - want CachingConfig - }{ - { - name: "full caching config", - config: EncryptionConfig{ - Caching: CachingConfig{ - MaxCache: 100, - MaxAge: "1h", - MaxUsage: 1000, - }, - }, - want: CachingConfig{ - MaxCache: 100, - MaxAge: "1h", - MaxUsage: 1000, - }, - }, - { - name: "partial caching config", - config: EncryptionConfig{ - Caching: CachingConfig{ - MaxCache: 50, - }, - }, - want: CachingConfig{ - MaxCache: 50, - }, - }, - { - name: "empty caching config", - config: EncryptionConfig{ - Caching: CachingConfig{}, - }, - want: CachingConfig{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.config.Caching.MaxCache != tt.want.MaxCache { - t.Errorf("Expected MaxCache to be %d, got %d", tt.want.MaxCache, tt.config.Caching.MaxCache) - } - if tt.config.Caching.MaxAge != tt.want.MaxAge { - t.Errorf("Expected MaxAge to be %s, got %s", tt.want.MaxAge, tt.config.Caching.MaxAge) - } - if tt.config.Caching.MaxUsage != tt.want.MaxUsage { - t.Errorf("Expected MaxUsage to be %d, got %d", tt.want.MaxUsage, tt.config.Caching.MaxUsage) - } - }) - } -} - -func TestTemporalAuthConfig_Structure(t *testing.T) { - tests := []struct { - name string - config TemporalAuthConfig - desc string - }{ - { - name: "TLS authentication", - config: TemporalAuthConfig{ - TLS: &TLSConfig{ - CertFile: "/path/to/cert.crt", - KeyFile: "/path/to/key.key", - }, - }, - desc: "should have TLS config and no API key", - }, - { - name: "API key authentication with value", - config: TemporalAuthConfig{ - ApiKey: &TemporalApiKeyConfig{ - Value: "test-api-key", - }, - }, - desc: "should have API key and no TLS config", - }, - { - name: "API key authentication with env var", - config: TemporalAuthConfig{ - ApiKey: &TemporalApiKeyConfig{ - EnvVar: "TEMPORAL_API_KEY", - }, - }, - desc: "should have API key env var and no TLS config", - }, - { - name: "empty authentication", - config: TemporalAuthConfig{}, - desc: "should have neither TLS nor API key", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - switch tt.name { - case "TLS authentication": - if tt.config.TLS == nil { - t.Error("Expected TLS to not be nil") - } else { - if tt.config.TLS.CertFile != "/path/to/cert.crt" { - t.Errorf("Expected CertFile to be '/path/to/cert.crt', got %s", tt.config.TLS.CertFile) - } - if tt.config.TLS.KeyFile != "/path/to/key.key" { - t.Errorf("Expected KeyFile to be '/path/to/key.key', got %s", tt.config.TLS.KeyFile) - } - } - if tt.config.ApiKey != nil { - t.Error("Expected ApiKey to be nil") - } - case "API key authentication with value": - if tt.config.ApiKey == nil { - t.Error("Expected ApiKey to not be nil") - } else { - if tt.config.ApiKey.Value != "test-api-key" { - t.Errorf("Expected ApiKey.Value to be 'test-api-key', got %s", tt.config.ApiKey.Value) - } - if tt.config.ApiKey.EnvVar != "" { - t.Errorf("Expected ApiKey.EnvVar to be empty, got %s", tt.config.ApiKey.EnvVar) - } - } - if tt.config.TLS != nil { - t.Error("Expected TLS to be nil") - } - case "API key authentication with env var": - if tt.config.ApiKey == nil { - t.Error("Expected ApiKey to not be nil") - } else { - if tt.config.ApiKey.EnvVar != "TEMPORAL_API_KEY" { - t.Errorf("Expected ApiKey.EnvVar to be 'TEMPORAL_API_KEY', got %s", tt.config.ApiKey.EnvVar) - } - if tt.config.ApiKey.Value != "" { - t.Errorf("Expected ApiKey.Value to be empty, got %s", tt.config.ApiKey.Value) - } - } - if tt.config.TLS != nil { - t.Error("Expected TLS to be nil") - } - case "empty authentication": - if tt.config.TLS != nil { - t.Error("Expected TLS to be nil") - } - if tt.config.ApiKey != nil { - t.Error("Expected ApiKey to be nil") - } - } - }) - } -} - -// Helper function to compare Config structs -func configEqual(a, b Config) bool { - if a.Server.Port != b.Server.Port || a.Server.Host != b.Server.Host { - return false - } - - if len(a.Workloads) != len(b.Workloads) { - return false - } - - for i, workloadA := range a.Workloads { - workloadB := b.Workloads[i] - if !workloadConfigEqual(workloadA, workloadB) { - return false - } - } - - return true -} - -func workloadConfigEqual(a, b WorkloadConfig) bool { - if a.WorkloadId != b.WorkloadId || a.EncryptionKey != b.EncryptionKey { - return false - } - - // Compare TemporalCloud configuration - if a.TemporalCloud.Namespace != b.TemporalCloud.Namespace || a.TemporalCloud.HostPort != b.TemporalCloud.HostPort { - return false - } - - // Compare TemporalCloud Authentication - API Key - if (a.TemporalCloud.Authentication.ApiKey == nil) != (b.TemporalCloud.Authentication.ApiKey == nil) { - return false - } - - if a.TemporalCloud.Authentication.ApiKey != nil && b.TemporalCloud.Authentication.ApiKey != nil { - if a.TemporalCloud.Authentication.ApiKey.Value != b.TemporalCloud.Authentication.ApiKey.Value || - a.TemporalCloud.Authentication.ApiKey.EnvVar != b.TemporalCloud.Authentication.ApiKey.EnvVar { - return false - } - } - - // Compare TLS configuration - if (a.TemporalCloud.Authentication.TLS == nil) != (b.TemporalCloud.Authentication.TLS == nil) { - return false - } - - if a.TemporalCloud.Authentication.TLS != nil && b.TemporalCloud.Authentication.TLS != nil { - if a.TemporalCloud.Authentication.TLS.CertFile != b.TemporalCloud.Authentication.TLS.CertFile || - a.TemporalCloud.Authentication.TLS.KeyFile != b.TemporalCloud.Authentication.TLS.KeyFile { - return false - } - } - - // Compare proxy Authentication (spiffe, oauth, etc.) - if (a.Authentication == nil) != (b.Authentication == nil) { - return false - } - - if a.Authentication != nil && b.Authentication != nil { - if a.Authentication.Type != b.Authentication.Type { - return false - } - if len(a.Authentication.Config) != len(b.Authentication.Config) { - return false - } - // Simple comparison - in production you might want more sophisticated comparison - for key, valueA := range a.Authentication.Config { - if valueB, exists := b.Authentication.Config[key]; !exists { - return false - } else { - // Handle slice comparison for audiences - if sliceA, okA := valueA.([]interface{}); okA { - if sliceB, okB := valueB.([]interface{}); okB { - if len(sliceA) != len(sliceB) { - return false - } - for j, itemA := range sliceA { - if itemA != sliceB[j] { - return false - } - } - } else { - return false - } - } else if valueA != valueB { - return false - } - } - } - } - - return true -} diff --git a/utils/crypt.go b/utils/crypt.go deleted file mode 100644 index 569108f..0000000 --- a/utils/crypt.go +++ /dev/null @@ -1,48 +0,0 @@ -package utils - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "fmt" - "io" -) - -func Encrypt(plainData []byte, key []byte) ([]byte, error) { - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(c) - if err != nil { - return nil, err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return nil, err - } - - return gcm.Seal(nonce, nonce, plainData, nil), nil -} - -func Decrypt(encryptedData []byte, key []byte) ([]byte, error) { - c, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(c) - if err != nil { - return nil, err - } - - nonceSize := gcm.NonceSize() - if len(encryptedData) < nonceSize { - return nil, fmt.Errorf("ciphertext too short: %v", encryptedData) - } - - nonce, encryptedData := encryptedData[:nonceSize], encryptedData[nonceSize:] - return gcm.Open(nil, nonce, encryptedData, nil) -}