Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions httputilx/httputilx.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,9 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption
backoff := o.initialBackoff

for attempt := 0; attempt <= o.maxRetries; attempt++ {
reqClone := req.Clone(req.Context())
if req.Body != nil {
if seeker, ok := req.Body.(interface {
Seek(int64, int) (int64, error)
}); ok {
_, _ = seeker.Seek(0, 0)
}
reqClone.Body = req.Body
reqClone, err := cloneWithBody(req)
if err != nil {
return nil, errors.Wrap(err, "failed to clone request with body")
}

resp, err := o.client.Do(reqClone)
Expand Down Expand Up @@ -272,3 +267,56 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption

return nil, fmt.Errorf("request failed after %d attempts", o.maxRetries+1)
}

func cloneWithBody(req *http.Request) (*http.Request, error) {
newReq := req.Clone(req.Context())
if req.Body == nil {
return newReq, nil
}
if req.GetBody != nil {
var err error
newReq.Body, err = req.GetBody()
if err != nil {
return nil, err
}
return newReq, nil
}

if seeker, ok := req.Body.(io.Seeker); ok {
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
return nil, err
}
newReq.Body = req.Body
newReq.GetBody = func() (io.ReadCloser, error) {
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
return nil, err
}
return req.Body, nil
}
return newReq, nil
}

bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
if err := req.Body.Close(); err != nil {
return nil, err
}

createBody := func(bodyBytes []byte) func() (io.ReadCloser, error) {
return func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(bodyBytes)), nil
}
}

req.GetBody = createBody(bodyBytes)
req.Body, _ = req.GetBody()
req.ContentLength = int64(len(bodyBytes))

newReq.GetBody = createBody(bodyBytes)
newReq.Body, _ = newReq.GetBody()
newReq.ContentLength = int64(len(bodyBytes))

return newReq, nil
}
56 changes: 55 additions & 1 deletion httputilx/httputilx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ func TestDoExponentialBackoff(t *testing.T) {
name string
options []ExponentialBackoffOption
handler http.HandlerFunc
requestBody io.Reader
wantBody string
wantErr string
wantAttempts int
Expand Down Expand Up @@ -323,6 +324,55 @@ func TestDoExponentialBackoff(t *testing.T) {
wantErr: "",
wantAttempts: 3,
},
{
name: "RequestBodyCopiedOnRetry",
options: []ExponentialBackoffOption{
ExponentialBackoffWithConfig(4, 100*time.Millisecond, 5*time.Second, 2.0),
},
handler: func() http.HandlerFunc {
initialBody := "request body content"

attempts := 0
return func(w http.ResponseWriter, r *http.Request) {
attempts++

if r.ContentLength != int64(len(initialBody)) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprintf(w, "wrong content-length: got %d, want %d", r.ContentLength, len(initialBody))
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}

if len(body) != len(initialBody) {
w.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprintf(w, "content-length mismatch: header=%d actual=%d", r.ContentLength, len(body))
return
}

// Verify body is correctly sent on all attempts
if string(body) != initialBody {
w.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprintf(w, "incorrect body: %q", string(body))
return
}
if attempts < 3 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("body received correctly"))
}
}(),
requestBody: io.NopCloser(bytes.NewBuffer([]byte("request body content"))),
wantBody: "body received correctly",
wantErr: "",
wantAttempts: 3,
},
}

for _, tt := range tests {
Expand All @@ -334,7 +384,11 @@ func TestDoExponentialBackoff(t *testing.T) {
}))
defer ts.Close()

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
method := http.MethodGet
if tt.requestBody != nil {
method = http.MethodPost
}
req, err := http.NewRequest(method, ts.URL, tt.requestBody)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
Expand Down