Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions cmd/api/api/instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,80 @@ func (s *ApiService) StatInstancePath(ctx context.Context, request oapi.StatInst
return response, nil
}

// UpdateInstanceCredentials replaces credential brokering policies for an instance.
// The id parameter can be an instance ID, name, or ID prefix.
// Note: Resolution is handled by ResolveResource middleware.
func (s *ApiService) UpdateInstanceCredentials(ctx context.Context, request oapi.UpdateInstanceCredentialsRequestObject) (oapi.UpdateInstanceCredentialsResponseObject, error) {
inst := mw.GetResolvedInstance[instances.Instance](ctx)
if inst == nil {
return oapi.UpdateInstanceCredentials500JSONResponse{
Code: "internal_error",
Message: "resource not resolved",
}, nil
}
log := logger.FromContext(ctx)

if request.Body == nil {
return oapi.UpdateInstanceCredentials400JSONResponse{
Code: "invalid_request",
Message: "request body is required",
}, nil
}

// Convert OAPI credential types to domain types
credentials := make(map[string]instances.CredentialPolicy, len(request.Body.Credentials))
for credentialName, credential := range request.Body.Credentials {
policy := instances.CredentialPolicy{
Source: instances.CredentialSource{
Env: credential.Source.Env,
},
Inject: make([]instances.CredentialInjectRule, 0, len(credential.Inject)),
}
for _, inject := range credential.Inject {
rule := instances.CredentialInjectRule{
As: instances.CredentialInjectAs{
Header: inject.As.Header,
Format: inject.As.Format,
},
}
if inject.Hosts != nil {
rule.Hosts = append([]string(nil), (*inject.Hosts)...)
}
policy.Inject = append(policy.Inject, rule)
}
credentials[credentialName] = policy
}

env := make(map[string]string)
if request.Body.Env != nil {
env = *request.Body.Env
}

result, err := s.InstanceManager.UpdateInstanceCredentials(ctx, inst.Id, credentials, env)
if err != nil {
switch {
case errors.Is(err, instances.ErrNotFound):
return oapi.UpdateInstanceCredentials404JSONResponse{
Code: "not_found",
Message: "instance not found",
}, nil
case errors.Is(err, instances.ErrInvalidRequest):
return oapi.UpdateInstanceCredentials400JSONResponse{
Code: "invalid_request",
Message: err.Error(),
}, nil
default:
log.ErrorContext(ctx, "failed to update instance credentials", "error", err)
return oapi.UpdateInstanceCredentials500JSONResponse{
Code: "internal_error",
Message: "failed to update instance credentials",
}, nil
}
}

return oapi.UpdateInstanceCredentials200JSONResponse(instanceToOAPI(*result)), nil
}

// AttachVolume attaches a volume to an instance (not yet implemented)
func (s *ApiService) AttachVolume(ctx context.Context, request oapi.AttachVolumeRequestObject) (oapi.AttachVolumeResponseObject, error) {
return oapi.AttachVolume500JSONResponse{
Expand Down
205 changes: 205 additions & 0 deletions cmd/api/api/instances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,25 @@ func TestCreateInstance_InvalidSizeFormat(t *testing.T) {
assert.Contains(t, badReq.Message, "invalid size format")
}

type captureUpdateCredentialsManager struct {
instances.Manager
lastID string
lastCredentials map[string]instances.CredentialPolicy
lastEnv map[string]string
result *instances.Instance
err error
}

func (m *captureUpdateCredentialsManager) UpdateInstanceCredentials(ctx context.Context, id string, credentials map[string]instances.CredentialPolicy, env map[string]string) (*instances.Instance, error) {
m.lastID = id
m.lastCredentials = credentials
m.lastEnv = env
if m.err != nil {
return nil, m.err
}
return m.result, nil
}

type captureCreateManager struct {
instances.Manager
lastReq *instances.CreateInstanceRequest
Expand Down Expand Up @@ -688,3 +707,189 @@ func waitForState(t *testing.T, svc *ApiService, instanceID string, expectedStat
}
t.Fatalf("Timeout waiting for instance to reach %s state", expectedState)
}

func TestUpdateInstanceCredentials_Success(t *testing.T) {
t.Parallel()
svc := newTestService(t)

now := time.Now()
source := instances.Instance{
StoredMetadata: instances.StoredMetadata{
Id: "inst-creds",
Name: "inst-creds",
Image: "docker.io/library/alpine:latest",
CreatedAt: now,
HypervisorType: hypervisor.TypeCloudHypervisor,
},
State: instances.StateRunning,
}

mockMgr := &captureUpdateCredentialsManager{
Manager: svc.InstanceManager,
result: &source,
}
svc.InstanceManager = mockMgr

credentials := map[string]oapi.CreateInstanceRequestCredential{
"OPENAI_KEY": {
Source: oapi.CreateInstanceRequestCredentialSource{Env: "OPENAI_KEY"},
Inject: []oapi.CreateInstanceRequestCredentialInject{
{
Hosts: &[]string{"api.openai.com"},
As: oapi.CreateInstanceRequestCredentialInjectAs{
Header: "Authorization",
Format: "Bearer ${value}",
},
},
},
},
}
env := map[string]string{"OPENAI_KEY": "sk-rotated-key"}

resp, err := svc.UpdateInstanceCredentials(
mw.WithResolvedInstance(ctx(), source.Id, source),
oapi.UpdateInstanceCredentialsRequestObject{
Id: source.Id,
Body: &oapi.UpdateInstanceCredentialsRequest{
Credentials: credentials,
Env: &env,
},
},
)
require.NoError(t, err)

ok200, ok := resp.(oapi.UpdateInstanceCredentials200JSONResponse)
require.True(t, ok, "expected 200 response, got %T", resp)
assert.Equal(t, "inst-creds", ok200.Id)

// Verify domain conversion
assert.Equal(t, source.Id, mockMgr.lastID)
require.Contains(t, mockMgr.lastCredentials, "OPENAI_KEY")
policy := mockMgr.lastCredentials["OPENAI_KEY"]
assert.Equal(t, "OPENAI_KEY", policy.Source.Env)
require.Len(t, policy.Inject, 1)
assert.Equal(t, []string{"api.openai.com"}, policy.Inject[0].Hosts)
assert.Equal(t, "Authorization", policy.Inject[0].As.Header)
assert.Equal(t, "Bearer ${value}", policy.Inject[0].As.Format)
assert.Equal(t, "sk-rotated-key", mockMgr.lastEnv["OPENAI_KEY"])
}

func TestUpdateInstanceCredentials_EmptyCredentialsClears(t *testing.T) {
t.Parallel()
svc := newTestService(t)

now := time.Now()
source := instances.Instance{
StoredMetadata: instances.StoredMetadata{
Id: "inst-clear-creds",
Name: "inst-clear-creds",
Image: "docker.io/library/alpine:latest",
CreatedAt: now,
HypervisorType: hypervisor.TypeCloudHypervisor,
},
State: instances.StateRunning,
}

mockMgr := &captureUpdateCredentialsManager{
Manager: svc.InstanceManager,
result: &source,
}
svc.InstanceManager = mockMgr

resp, err := svc.UpdateInstanceCredentials(
mw.WithResolvedInstance(ctx(), source.Id, source),
oapi.UpdateInstanceCredentialsRequestObject{
Id: source.Id,
Body: &oapi.UpdateInstanceCredentialsRequest{
Credentials: map[string]oapi.CreateInstanceRequestCredential{},
},
},
)
require.NoError(t, err)

_, ok := resp.(oapi.UpdateInstanceCredentials200JSONResponse)
require.True(t, ok, "expected 200 response")
assert.Empty(t, mockMgr.lastCredentials)
}

func TestUpdateInstanceCredentials_InvalidRequest(t *testing.T) {
t.Parallel()
svc := newTestService(t)

source := instances.Instance{
StoredMetadata: instances.StoredMetadata{
Id: "inst-invalid-creds",
Name: "inst-invalid-creds",
Image: "docker.io/library/alpine:latest",
CreatedAt: time.Now(),
HypervisorType: hypervisor.TypeCloudHypervisor,
},
State: instances.StateRunning,
}

mockMgr := &captureUpdateCredentialsManager{
Manager: svc.InstanceManager,
err: fmt.Errorf("%w: credentials require network.egress.enabled=true", instances.ErrInvalidRequest),
}
svc.InstanceManager = mockMgr

resp, err := svc.UpdateInstanceCredentials(
mw.WithResolvedInstance(ctx(), source.Id, source),
oapi.UpdateInstanceCredentialsRequestObject{
Id: source.Id,
Body: &oapi.UpdateInstanceCredentialsRequest{
Credentials: map[string]oapi.CreateInstanceRequestCredential{
"KEY": {
Source: oapi.CreateInstanceRequestCredentialSource{Env: "KEY"},
Inject: []oapi.CreateInstanceRequestCredentialInject{
{As: oapi.CreateInstanceRequestCredentialInjectAs{Header: "X-Key", Format: "${value}"}},
},
},
},
},
},
)
require.NoError(t, err)

badReq, ok := resp.(oapi.UpdateInstanceCredentials400JSONResponse)
require.True(t, ok, "expected 400 response, got %T", resp)
assert.Equal(t, "invalid_request", badReq.Code)
assert.Contains(t, badReq.Message, "credentials require network.egress.enabled=true")
}

func TestUpdateInstanceCredentials_NotFound(t *testing.T) {
t.Parallel()
svc := newTestService(t)

source := instances.Instance{
StoredMetadata: instances.StoredMetadata{
Id: "inst-gone",
Name: "inst-gone",
Image: "docker.io/library/alpine:latest",
CreatedAt: time.Now(),
HypervisorType: hypervisor.TypeCloudHypervisor,
},
State: instances.StateRunning,
}

mockMgr := &captureUpdateCredentialsManager{
Manager: svc.InstanceManager,
err: instances.ErrNotFound,
}
svc.InstanceManager = mockMgr

resp, err := svc.UpdateInstanceCredentials(
mw.WithResolvedInstance(ctx(), source.Id, source),
oapi.UpdateInstanceCredentialsRequestObject{
Id: source.Id,
Body: &oapi.UpdateInstanceCredentialsRequest{
Credentials: map[string]oapi.CreateInstanceRequestCredential{},
},
},
)
require.NoError(t, err)

notFound, ok := resp.(oapi.UpdateInstanceCredentials404JSONResponse)
require.True(t, ok, "expected 404 response, got %T", resp)
assert.Equal(t, "not_found", notFound.Code)
}
20 changes: 20 additions & 0 deletions lib/egressproxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,26 @@ func compileHeaderInjectRules(cfgRules []HeaderInjectRuleConfig) ([]headerInject
return out, nil
}

// UpdateInstancePolicy updates the header injection rules for an already-registered instance.
// Returns false if the instance is not registered (no-op).
func (s *Service) UpdateInstancePolicy(instanceID string, rules []HeaderInjectRuleConfig) (bool, error) {
compiled, err := compileHeaderInjectRules(rules)
if err != nil {
return false, err
}

s.mu.Lock()
defer s.mu.Unlock()

sourceIP, ok := s.sourceIPByInstance[instanceID]
if !ok {
return false, nil
}

s.policiesBySourceIP[sourceIP] = sourcePolicy{headerInjectRules: compiled}
return true, nil
}

func (s *Service) UnregisterInstance(_ context.Context, instanceID string) {
s.mu.Lock()
sourceIP, ok := s.sourceIPByInstance[instanceID]
Expand Down
Loading
Loading