diff --git a/storage/http_client.go b/storage/http_client.go index d117d90e551e..9f0d04b5ce6f 100644 --- a/storage/http_client.go +++ b/storage/http_client.go @@ -45,8 +45,6 @@ import ( // httpStorageClient is the HTTP-JSON API implementation of the transport-agnostic // storageClient interface. -// -// TODO(b/498422946): Add client feature tracker in HTTP client. type httpStorageClient struct { creds *auth.Credentials hc *http.Client @@ -56,6 +54,10 @@ type httpStorageClient struct { settings *settings config *storageConfig dynamicReadReqStallTimeout *bucketDelayManager + + // configFeatureAttributes tracks client-level features that are enabled for this + // client instance. + configFeatureAttributes uint32 } // newHTTPStorageClient initializes a new storageClient that uses the HTTP-JSON @@ -119,8 +121,24 @@ func newHTTPStorageClient(ctx context.Context, opts ...storageOption) (storageCl if err != nil { return nil, fmt.Errorf("dialing: %w", err) } + + // Clone the http.Client to avoid modifying the original one if it was provided by the user. + hcClone := *hc + c := &httpStorageClient{ + creds: creds, + hc: &hcClone, + settings: s, + config: &config, + } + + // Wrap transport to inject tracking headers. + hcClone.Transport = &trackingTransport{ + base: hc.Transport, + features: c.configFeatureAttributes, + } + // RawService should be created with the chosen endpoint to take account of user override. - rawService, err := raw.NewService(ctx, option.WithEndpoint(ep), option.WithHTTPClient(hc)) + rawService, err := raw.NewService(ctx, option.WithEndpoint(ep), option.WithHTTPClient(c.hc)) if err != nil { return nil, fmt.Errorf("storage client: %w", err) } @@ -144,16 +162,12 @@ func newHTTPStorageClient(ctx context.Context, opts ...storageOption) (storageCl } } - return &httpStorageClient{ - creds: creds, - hc: hc, - xmlHost: u.Host, - raw: rawService, - scheme: u.Scheme, - settings: s, - config: &config, - dynamicReadReqStallTimeout: bd, - }, nil + c.xmlHost = u.Host + c.raw = rawService + c.scheme = u.Scheme + c.dynamicReadReqStallTimeout = bd + + return c, nil } func (c *httpStorageClient) Close() error { @@ -161,6 +175,43 @@ func (c *httpStorageClient) Close() error { return nil } +// trackingTransport wraps an http.RoundTripper to inject feature tracking headers. +type trackingTransport struct { + base http.RoundTripper + features uint32 +} + +func (t *trackingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + baseRT := t.base + if baseRT == nil { + baseRT = http.DefaultTransport + } + features := t.features | featureAttributes(req.Context()) + // Merge all existing headers for this key. + features |= mergeFeatureAttributes(req.Header.Values(featureTrackerHeaderName)) + if features > 0 { + // Clone the request to avoid modifying the original one. + clonedReq := req.Clone(req.Context()) + clonedReq.Header.Set(featureTrackerHeaderName, encodeUint32(features)) + return baseRT.RoundTrip(clonedReq) + } + + return baseRT.RoundTrip(req) +} + +func (t *trackingTransport) CloseIdleConnections() { + type closeIdler interface { + CloseIdleConnections() + } + base := t.base + if base == nil { + base = http.DefaultTransport + } + if tr, ok := base.(closeIdler); ok { + tr.CloseIdleConnections() + } +} + // Top-level methods. func (c *httpStorageClient) GetServiceAccount(ctx context.Context, project string, opts ...storageOption) (string, error) { diff --git a/storage/http_client_test.go b/storage/http_client_test.go index 5301a3359e6f..477864ad3e34 100644 --- a/storage/http_client_test.go +++ b/storage/http_client_test.go @@ -171,3 +171,55 @@ func TestValidateChecksumFromServer(t *testing.T) { }) } } + +func TestTrackingTransport(t *testing.T) { + mock := &mockTransport{} + mock.addResult(&http.Response{Status: "200 OK"}, nil) + tt := &trackingTransport{ + base: mock, + features: uint32(1 << featurePCU), + } + + ctx := context.Background() + ctx = addFeatureAttributes(ctx, featureMultistreamInMRD) + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + + _, err := tt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + + gotHeader := mock.gotReq.Header.Get(featureTrackerHeaderName) + wantFeatures := uint32(1<