diff --git a/httputilx/httputilx.go b/httputilx/httputilx.go index c135bde..a98389f 100644 --- a/httputilx/httputilx.go +++ b/httputilx/httputilx.go @@ -233,14 +233,9 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption backoff := o.initialBackoff for attempt := 0; attempt <= o.maxRetries; attempt++ { - reqClone := req.Clone(req.Context()) - if req.Body != nil { - if seeker, ok := req.Body.(interface { - Seek(int64, int) (int64, error) - }); ok { - _, _ = seeker.Seek(0, 0) - } - reqClone.Body = req.Body + reqClone, err := cloneWithBody(req) + if err != nil { + return nil, errors.Wrap(err, "failed to clone request with body") } resp, err := o.client.Do(reqClone) @@ -272,3 +267,56 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption return nil, fmt.Errorf("request failed after %d attempts", o.maxRetries+1) } + +func cloneWithBody(req *http.Request) (*http.Request, error) { + newReq := req.Clone(req.Context()) + if req.Body == nil { + return newReq, nil + } + if req.GetBody != nil { + var err error + newReq.Body, err = req.GetBody() + if err != nil { + return nil, err + } + return newReq, nil + } + + if seeker, ok := req.Body.(io.Seeker); ok { + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return nil, err + } + newReq.Body = req.Body + newReq.GetBody = func() (io.ReadCloser, error) { + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return nil, err + } + return req.Body, nil + } + return newReq, nil + } + + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + if err := req.Body.Close(); err != nil { + return nil, err + } + + createBody := func(bodyBytes []byte) func() (io.ReadCloser, error) { + return func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyBytes)), nil + } + } + + req.GetBody = createBody(bodyBytes) + req.Body, _ = req.GetBody() + req.ContentLength = int64(len(bodyBytes)) + + newReq.GetBody = createBody(bodyBytes) + newReq.Body, _ = newReq.GetBody() + newReq.ContentLength = int64(len(bodyBytes)) + + return newReq, nil +} diff --git a/httputilx/httputilx_test.go b/httputilx/httputilx_test.go index 89df920..3ae29b4 100644 --- a/httputilx/httputilx_test.go +++ b/httputilx/httputilx_test.go @@ -230,6 +230,7 @@ func TestDoExponentialBackoff(t *testing.T) { name string options []ExponentialBackoffOption handler http.HandlerFunc + requestBody io.Reader wantBody string wantErr string wantAttempts int @@ -323,6 +324,55 @@ func TestDoExponentialBackoff(t *testing.T) { wantErr: "", wantAttempts: 3, }, + { + name: "RequestBodyCopiedOnRetry", + options: []ExponentialBackoffOption{ + ExponentialBackoffWithConfig(4, 100*time.Millisecond, 5*time.Second, 2.0), + }, + handler: func() http.HandlerFunc { + initialBody := "request body content" + + attempts := 0 + return func(w http.ResponseWriter, r *http.Request) { + attempts++ + + if r.ContentLength != int64(len(initialBody)) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintf(w, "wrong content-length: got %d, want %d", r.ContentLength, len(initialBody)) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if len(body) != len(initialBody) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintf(w, "content-length mismatch: header=%d actual=%d", r.ContentLength, len(body)) + return + } + + // Verify body is correctly sent on all attempts + if string(body) != initialBody { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintf(w, "incorrect body: %q", string(body)) + return + } + if attempts < 3 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("body received correctly")) + } + }(), + requestBody: io.NopCloser(bytes.NewBuffer([]byte("request body content"))), + wantBody: "body received correctly", + wantErr: "", + wantAttempts: 3, + }, } for _, tt := range tests { @@ -334,7 +384,11 @@ func TestDoExponentialBackoff(t *testing.T) { })) defer ts.Close() - req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + method := http.MethodGet + if tt.requestBody != nil { + method = http.MethodPost + } + req, err := http.NewRequest(method, ts.URL, tt.requestBody) if err != nil { t.Fatalf("failed to create request: %v", err) }