diff --git a/firestore/pipeline.go b/firestore/pipeline.go index c269e1e8bd28..6eef31b5ae6e 100644 --- a/firestore/pipeline.go +++ b/firestore/pipeline.go @@ -822,3 +822,94 @@ func (p *Pipeline) RawStage(name string, args []any, opts ...RawStageOptions) *P } return p.append(stage) } + +// UpdateOption is an option for an Update pipeline stage. +// +// Experimental: Firestore Pipelines is currently in preview and is subject to potential breaking changes in future versions, +// regardless of any other documented package stability guarantees. +type UpdateOption interface { + isUpdateOption() +} + +type updateTransformationsOption struct { + fields []Selectable +} + +func (updateTransformationsOption) isUpdateOption() {} + +// WithUpdateTransformations specifies the list of field transformations to apply in an update operation. +// +// Experimental: Firestore Pipelines is currently in preview and is subject to potential breaking changes in future versions, +// regardless of any other documented package stability guarantees. +func WithUpdateTransformations(field Selectable, additionalFields ...Selectable) UpdateOption { + return updateTransformationsOption{ + fields: append([]Selectable{field}, additionalFields...), + } +} + +// Update performs an update operation using documents from previous stages. +// +// This method updates the documents in place based on the data flowing through the pipeline. +// You can optionally specify a list of [Selectable] field transformations using [WithUpdateTransformations]. +// If no transformations are provided, it performs the update in-place without any changes. +// +// Example: +// +// // In-place update +// client.Pipeline().Literals(updateData).Update() +// +// // Update with transformations +// client.Pipeline().Collection("books"). +// Where(GreaterThan("price", 50)). +// Update(WithUpdateTransformations(ConstantOf("Discounted").As("status"))) +// +// Experimental: Firestore Pipelines is currently in preview and is subject to potential breaking changes in future versions, +// regardless of any other documented package stability guarantees. +func (p *Pipeline) Update(opts ...UpdateOption) *Pipeline { + if p.err != nil { + return p + } + + var transformations []Selectable + for _, opt := range opts { + if opt != nil { + switch o := opt.(type) { + case updateTransformationsOption: + transformations = append(transformations, o.fields...) + } + } + } + + stage, err := newUpdateStage(transformations) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// DeleteOption is an option for a Delete pipeline stage. +// +// Experimental: Firestore Pipelines is currently in preview and is subject to potential breaking changes in future versions, +// regardless of any other documented package stability guarantees. +type DeleteOption interface { + isDeleteOption() +} + +// Delete deletes the documents from previous stages. +// +// Example: +// +// client.Pipeline().Collection("logs"). +// Where(Equal("status", "archived")). +// Delete() +// +// Experimental: Firestore Pipelines is currently in preview and is subject to potential breaking changes in future versions, +// regardless of any other documented package stability guarantees. +func (p *Pipeline) Delete(opts ...DeleteOption) *Pipeline { + if p.err != nil { + return p + } + stage := newDeleteStage() + return p.append(stage) +} diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index da7da4671d79..0e9804561055 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -859,6 +859,47 @@ func TestIntegration_PipelineStages(t *testing.T) { t.Errorf("got %d documents, want 4", len(results)) } }) + t.Run("Update", func(t *testing.T) { + t.Skip("Skipping test until feature is available in PROD") + updateIter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("author.country"), "UK")). + Update(WithUpdateTransformations(ConstantOf("Active").As("status"))). + Execute(ctx).Results() + defer updateIter.Stop() + _, err := updateIter.GetAll() + if err != nil { + t.Fatalf("Failed to execute update: %v", err) + } + + verifyIter := client.Pipeline().Collection(coll.ID).Where(Equal(FieldOf("status"), "Active")).Execute(ctx).Results() + defer verifyIter.Stop() + results, err := verifyIter.GetAll() + if err != nil { + t.Fatalf("Failed to execute verify: %v", err) + } + if len(results) != 4 { + t.Errorf("got %d updated documents, want 4", len(results)) + } + }) + t.Run("Delete", func(t *testing.T) { + t.Skip("Skipping test until feature is available in PROD") + deleteIter := client.Pipeline().Collection(coll.ID).Where(Equal(FieldOf("title"), "The Great Gatsby")).Delete().Execute(ctx).Results() + defer deleteIter.Stop() + _, err := deleteIter.GetAll() + if err != nil { + t.Fatalf("Failed to execute delete: %v", err) + } + + verifyIter := client.Pipeline().Collection(coll.ID).Where(Equal(FieldOf("title"), "The Great Gatsby")).Execute(ctx).Results() + defer verifyIter.Stop() + results, err := verifyIter.GetAll() + if err != nil { + t.Fatalf("Failed to execute verify: %v", err) + } + if len(results) != 0 { + t.Errorf("got %d documents, want 0 after delete", len(results)) + } + }) } func TestIntegration_PipelineFunctions(t *testing.T) { diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index cc9f48995bd4..f88288024b7f 100644 --- a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -42,6 +42,7 @@ const ( stageNameCollection = "collection" stageNameCollectionGroup = "collection_group" stageNameDatabase = "database" + stageNameDelete = "delete" stageNameDistinct = "distinct" stageNameDocuments = "documents" stageNameFindNearest = "find_nearest" @@ -52,6 +53,7 @@ const ( stageNameSelect = "select" stageNameUnion = "union" stageNameUnnest = "unnest" + stageNameUpdate = "update" stageNameWhere = "where" ) @@ -598,3 +600,46 @@ func (s *rawStage) toProto() (*pb.Pipeline_Stage, error) { Options: optionsPb, }, nil } + +type updateStage struct { + fields []Selectable +} + +func newUpdateStage(fields []Selectable) (*updateStage, error) { + return &updateStage{fields: fields}, nil +} + +func (s *updateStage) name() string { return stageNameUpdate } + +func (s *updateStage) toProto() (*pb.Pipeline_Stage, error) { + var mapVal *pb.Value + if len(s.fields) > 0 { + var err error + mapVal, err = projectionsToMapValue(s.fields) + if err != nil { + return nil, err + } + } else { + mapVal = &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{}}} + } + + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: []*pb.Value{mapVal}, + }, nil +} + +type deleteStage struct{} + +func newDeleteStage() *deleteStage { + return &deleteStage{} +} + +func (s *deleteStage) name() string { return stageNameDelete } + +func (s *deleteStage) toProto() (*pb.Pipeline_Stage, error) { + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: []*pb.Value{}, + }, nil +} diff --git a/firestore/pipeline_test.go b/firestore/pipeline_test.go index e889545806cd..413adf02c7fe 100644 --- a/firestore/pipeline_test.go +++ b/firestore/pipeline_test.go @@ -434,3 +434,85 @@ func TestPipeline_CreateFromQuery(t *testing.T) { t.Errorf("toExecutePipelineRequest() mismatch for collection stage (-want +got):\n%s", diff) } } + +func TestPipeline_Update(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Update(WithUpdateTransformations(ConstantOf("Active").As("status"))) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantUpdateStage := &pb.Pipeline_Stage{ + Name: "update", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "status": {ValueType: &pb.Value_StringValue{StringValue: "Active"}}, + }, + }, + }}, + }, + } + if diff := cmp.Diff(wantUpdateStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for update stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_Update_Empty(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Update() + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantUpdateStage := &pb.Pipeline_Stage{ + Name: "update", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{}}}, + }, + } + if diff := cmp.Diff(wantUpdateStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for update stage (empty args) (-want +got):\n%s", diff) + } +} + +func TestPipeline_Delete(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Delete() + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantDeleteStage := &pb.Pipeline_Stage{ + Name: "delete", + Args: []*pb.Value{}, + } + if diff := cmp.Diff(wantDeleteStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for delete stage (-want +got):\n%s", diff) + } +}