diff --git a/firestore/client.go b/firestore/client.go index ab3347bb7df4..35ace3a3678b 100644 --- a/firestore/client.go +++ b/firestore/client.go @@ -22,6 +22,7 @@ import ( "net/url" "os" "strings" + "syscall" "time" vkit "cloud.google.com/go/firestore/apiv1" @@ -324,53 +325,103 @@ func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte, batchGetDocsCtx := withResourceHeader(ctx, req.Database) batchGetDocsCtx = withRequestParamsHeader(batchGetDocsCtx, reqParamsHeaderVal(c.path())) - streamClient, err := c.c.BatchGetDocuments(batchGetDocsCtx, req) - if err != nil { - return nil, err + var backoff gax.Backoff = gax.Backoff{ + Initial: 100 * time.Millisecond, + Max: 60000 * time.Millisecond, + Multiplier: 1.30, } - // Read and remember all results from the stream. - var resps []*pb.BatchGetDocumentsResponse for { - resp, err := streamClient.Recv() - if err == io.EOF { - break - } - if err != nil { - return nil, err + var streamClient interface { + Recv() (*pb.BatchGetDocumentsResponse, error) } - resps = append(resps, resp) - } - - // Results may arrive out of order. Put each at the right indices. - docs := make([]*DocumentSnapshot, len(docNames)) - for _, resp := range resps { - var ( - indices []int - doc *pb.Document - err error - ) - switch r := resp.Result.(type) { - case *pb.BatchGetDocumentsResponse_Found: - indices = docIndices[r.Found.Name] - doc = r.Found - case *pb.BatchGetDocumentsResponse_Missing: - indices = docIndices[r.Missing] - doc = nil - default: - return nil, errors.New("firestore: unknown BatchGetDocumentsResponse result type") - } - for _, index := range indices { - if docs[index] != nil { - return nil, fmt.Errorf("firestore: %q seen twice", docRefs[index].Path) + streamClient, err = c.c.BatchGetDocuments(batchGetDocsCtx, req) + if err == nil { + + // Read and remember all results from the stream. + var resps []*pb.BatchGetDocumentsResponse + for { + var resp *pb.BatchGetDocumentsResponse + resp, err = streamClient.Recv() + if err == io.EOF { + break + } + if err != nil { + break + } + resps = append(resps, resp) } - docs[index], err = newDocumentSnapshot(docRefs[index], doc, c, resp.ReadTime) - if err != nil { - return nil, err + if err == io.EOF || err == nil { + // Successfully read everything from the stream. + // Results may arrive out of order. Put each at the right indices. + docs := make([]*DocumentSnapshot, len(docNames)) + for _, resp := range resps { + var ( + indices []int + doc *pb.Document + err error + ) + switch r := resp.Result.(type) { + case *pb.BatchGetDocumentsResponse_Found: + indices = docIndices[r.Found.Name] + doc = r.Found + case *pb.BatchGetDocumentsResponse_Missing: + indices = docIndices[r.Missing] + doc = nil + default: + return nil, errors.New("firestore: unknown BatchGetDocumentsResponse result type") + } + for _, index := range indices { + if docs[index] != nil { + return nil, fmt.Errorf("firestore: %q seen twice", docRefs[index].Path) + } + docs[index], err = newDocumentSnapshot(docRefs[index], doc, c, resp.ReadTime) + if err != nil { + return nil, err + } + } + } + return docs, nil } } + + // If we got an error, check if it's retryable. + if !isRetryableError(err) { + return nil, err + } + + // Sleep first using backoff + dur := backoff.Pause() + if sleepErr := gax.Sleep(ctx, dur); sleepErr != nil { + return nil, err // Return original error if context was canceled + } + } +} + +func isRetryableError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return false } - return docs, nil + if errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ECONNREFUSED) { + return true + } + + // Check gRPC status codes + if st, ok := status.FromError(err); ok { + return st.Code() == codes.Unavailable + } + + // Fallback for unexported/wrapped network errors using strings (similar to cloud-storage) + errStr := err.Error() + return strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "broken pipe") } // Collections returns an iterator over the top-level collections.