Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 64 additions & 13 deletions storage/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import (

// httpStorageClient is the HTTP-JSON API implementation of the transport-agnostic
// storageClient interface.
//
// TODO(b/498422946): Add client feature tracker in HTTP client.
type httpStorageClient struct {
creds *auth.Credentials
hc *http.Client
Expand All @@ -56,6 +54,10 @@ type httpStorageClient struct {
settings *settings
config *storageConfig
dynamicReadReqStallTimeout *bucketDelayManager

// configFeatureAttributes tracks client-level features that are enabled for this
// client instance.
configFeatureAttributes uint32
}

// newHTTPStorageClient initializes a new storageClient that uses the HTTP-JSON
Expand Down Expand Up @@ -119,8 +121,24 @@ func newHTTPStorageClient(ctx context.Context, opts ...storageOption) (storageCl
if err != nil {
return nil, fmt.Errorf("dialing: %w", err)
}

// Clone the http.Client to avoid modifying the original one if it was provided by the user.
hcClone := *hc
c := &httpStorageClient{
creds: creds,
hc: &hcClone,
settings: s,
config: &config,
}
Comment thread
krishnamd-jkp marked this conversation as resolved.
Comment thread
krishnamd-jkp marked this conversation as resolved.

// Wrap transport to inject tracking headers.
hcClone.Transport = &trackingTransport{
base: hc.Transport,
features: c.configFeatureAttributes,
}

// RawService should be created with the chosen endpoint to take account of user override.
rawService, err := raw.NewService(ctx, option.WithEndpoint(ep), option.WithHTTPClient(hc))
rawService, err := raw.NewService(ctx, option.WithEndpoint(ep), option.WithHTTPClient(c.hc))
if err != nil {
return nil, fmt.Errorf("storage client: %w", err)
}
Expand All @@ -144,23 +162,56 @@ func newHTTPStorageClient(ctx context.Context, opts ...storageOption) (storageCl
}
}

return &httpStorageClient{
creds: creds,
hc: hc,
xmlHost: u.Host,
raw: rawService,
scheme: u.Scheme,
settings: s,
config: &config,
dynamicReadReqStallTimeout: bd,
}, nil
c.xmlHost = u.Host
c.raw = rawService
c.scheme = u.Scheme
c.dynamicReadReqStallTimeout = bd

return c, nil
}

func (c *httpStorageClient) Close() error {
c.hc.CloseIdleConnections()
return nil
}

// trackingTransport wraps an http.RoundTripper to inject feature tracking headers.
type trackingTransport struct {
base http.RoundTripper
features uint32
}

func (t *trackingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
baseRT := t.base
if baseRT == nil {
baseRT = http.DefaultTransport
}
features := t.features | featureAttributes(req.Context())
// Merge all existing headers for this key.
features |= mergeFeatureAttributes(req.Header.Values(featureTrackerHeaderName))
if features > 0 {
// Clone the request to avoid modifying the original one.
clonedReq := req.Clone(req.Context())
clonedReq.Header.Set(featureTrackerHeaderName, encodeUint32(features))
return baseRT.RoundTrip(clonedReq)
}

return baseRT.RoundTrip(req)
}
Comment thread
krishnamd-jkp marked this conversation as resolved.

func (t *trackingTransport) CloseIdleConnections() {
type closeIdler interface {
CloseIdleConnections()
}
base := t.base
if base == nil {
base = http.DefaultTransport
}
if tr, ok := base.(closeIdler); ok {
tr.CloseIdleConnections()
}
}

// Top-level methods.

func (c *httpStorageClient) GetServiceAccount(ctx context.Context, project string, opts ...storageOption) (string, error) {
Expand Down
52 changes: 52 additions & 0 deletions storage/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,55 @@ func TestValidateChecksumFromServer(t *testing.T) {
})
}
}

func TestTrackingTransport(t *testing.T) {
mock := &mockTransport{}
mock.addResult(&http.Response{Status: "200 OK"}, nil)
tt := &trackingTransport{
base: mock,
features: uint32(1 << featurePCU),
}

ctx := context.Background()
ctx = addFeatureAttributes(ctx, featureMultistreamInMRD)
req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)

_, err := tt.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}

gotHeader := mock.gotReq.Header.Get(featureTrackerHeaderName)
wantFeatures := uint32(1<<featurePCU) | uint32(1<<featureMultistreamInMRD)
wantHeader := encodeUint32(wantFeatures)

if gotHeader != wantHeader {
t.Errorf("Header %s = %q; want %q", featureTrackerHeaderName, gotHeader, wantHeader)
}

// Verify original request was not modified.
if req.Header.Get(featureTrackerHeaderName) != "" {
t.Errorf("Original request header was modified")
}
}

type mockCloseIdler struct {
mockTransport
closedIdle bool
}

func (m *mockCloseIdler) CloseIdleConnections() {
m.closedIdle = true
}

func TestTrackingTransport_CloseIdleConnections(t *testing.T) {
mock := &mockCloseIdler{}
tt := &trackingTransport{
base: mock,
}

tt.CloseIdleConnections()
if !mock.closedIdle {
t.Errorf("CloseIdleConnections was not called on base transport")
}
}
Loading