diff --git a/firestore/bson_types.go b/firestore/bson_types.go new file mode 100644 index 000000000000..84d30407c4cc --- /dev/null +++ b/firestore/bson_types.go @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +// BSONObjectID represents a BSON ObjectID as a 24-character lowercase hex string. +type BSONObjectID string + +// String returns the string representation of the BSONObjectID. +func (id BSONObjectID) String() string { + return string(id) +} + +// BSONRegex represents a BSON Regular Expression. +type BSONRegex struct { + Pattern string + Options string +} + +// BSONTimestamp represents a BSON Timestamp. +type BSONTimestamp struct { + Seconds uint32 + Increment uint32 +} + +// BSONDecimal128 represents a BSON Decimal128. +type BSONDecimal128 string + +// BSONMinKey represents BSON MinKey. +type BSONMinKey struct{} + +// BSONMaxKey represents BSON MaxKey. +type BSONMaxKey struct{} + +// BSONBinary represents BSON Binary data with subtype != 0. +type BSONBinary struct { + Subtype byte + Data []byte +} + +// BSONInt32 represents a BSON 32-bit integer. +type BSONInt32 int32 diff --git a/firestore/from_value.go b/firestore/from_value.go index 10b434fc866e..7d4fad889a9b 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -113,6 +113,62 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien } vDest.Set(reflect.ValueOf(val)) return nil + case typeOfBSONObjectID: + val, err := bsonObjectIDFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONRegex: + val, err := bsonRegexFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONTimestamp: + val, err := bsonTimestampFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONDecimal128: + val, err := bsonDecimal128FromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONMinKey: + val, err := bsonMinKeyFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONMaxKey: + val, err := bsonMaxKeyFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONBinary: + val, err := bsonBinaryFromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfBSONInt32: + val, err := bsonInt32FromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil } switch vDest.Kind() { @@ -130,7 +186,7 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien } vDest.SetString(x.StringValue) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int64: var i int64 switch x := valTypeSrc.(type) { case *pb.Value_IntegerValue: @@ -149,6 +205,31 @@ func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Clien } vDest.SetInt(i) + case reflect.Int32: + var i int64 + switch x := valTypeSrc.(type) { + case *pb.Value_IntegerValue: + i = x.IntegerValue + case *pb.Value_DoubleValue: + f := x.DoubleValue + i = int64(f) + if float64(i) != f { + return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type()) + } + case *pb.Value_MapValue: + val, err := bsonInt32FromProtoValue(vprotoSrc) + if err != nil { + return err + } + i = int64(val) + default: + return typeErr() + } + if vDest.OverflowInt(i) { + return overflowErr(vDest, i) + } + vDest.SetInt(i) + case reflect.Uint8, reflect.Uint16, reflect.Uint32: var u uint64 switch x := valTypeSrc.(type) { @@ -415,13 +496,17 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) { } typeVal, ok := ret[typeKey] - if !ok || typeVal != typeValVector { - // Map is not a vector. Return the map - return ret, nil + if ok && typeVal == typeValVector { + return vector64FromProtoValue(vproto) } - // Special handling for vector - return vector64FromProtoValue(vproto) + if bsonVal, ok, err := tryConvertMapToBSONType(ret); err != nil { + return nil, err + } else if ok { + return bsonVal, nil + } + + return ret, nil default: return nil, fmt.Errorf("firestore: unknown value type %T", v) } @@ -483,3 +568,255 @@ func typeString(vproto *pb.Value) string { func overflowErr(v reflect.Value, x interface{}) error { return fmt.Errorf("firestore: value %v overflows type %s", x, v.Type()) } + +func bsonObjectIDFromProtoValue(v *pb.Value) (BSONObjectID, error) { + var id BSONObjectID + m, err := assertMapWithValueKey(v, "__oid__") + if err != nil { + return id, err + } + s, err := stringFromProtoValue(m["__oid__"]) + if err != nil { + return id, err + } + return BSONObjectID(s), nil +} + +func bsonRegexFromProtoValue(v *pb.Value) (BSONRegex, error) { + var r BSONRegex + m, err := assertMapWithValueKey(v, "__regex__") + if err != nil { + return r, err + } + regexMapVal := m["__regex__"] + regexMap, ok := regexMapVal.ValueType.(*pb.Value_MapValue) + if !ok { + return r, fmt.Errorf("firestore: failed to convert regex value %v to map", regexMapVal.ValueType) + } + rm := regexMap.MapValue.Fields + pattern, err := stringFromProtoValue(rm["pattern"]) + if err != nil { + return r, err + } + options, err := stringFromProtoValue(rm["options"]) + if err != nil { + return r, err + } + r = BSONRegex{Pattern: pattern, Options: options} + return r, nil +} + +func bsonTimestampFromProtoValue(v *pb.Value) (BSONTimestamp, error) { + var t BSONTimestamp + m, err := assertMapWithValueKey(v, "__request_timestamp__") + if err != nil { + return t, err + } + tsMapVal := m["__request_timestamp__"] + tsMap, ok := tsMapVal.ValueType.(*pb.Value_MapValue) + if !ok { + return t, fmt.Errorf("firestore: failed to convert timestamp value %v to map", tsMapVal.ValueType) + } + tm := tsMap.MapValue.Fields + + secondsVal, ok := tm["seconds"] + if !ok { + return t, fmt.Errorf("firestore: seconds missing in timestamp") + } + sv, ok := secondsVal.ValueType.(*pb.Value_IntegerValue) + if !ok { + return t, fmt.Errorf("firestore: seconds is not integer: %v", secondsVal.ValueType) + } + + incrementVal, ok := tm["increment"] + if !ok { + return t, fmt.Errorf("firestore: increment missing in timestamp") + } + iv, ok := incrementVal.ValueType.(*pb.Value_IntegerValue) + if !ok { + return t, fmt.Errorf("firestore: increment is not integer: %v", incrementVal.ValueType) + } + + if sv.IntegerValue < 0 || sv.IntegerValue > 0xffffffff { + return t, fmt.Errorf("firestore: BSON timestamp seconds out of range: %d", sv.IntegerValue) + } + if iv.IntegerValue < 0 || iv.IntegerValue > 0xffffffff { + return t, fmt.Errorf("firestore: BSON timestamp increment out of range: %d", iv.IntegerValue) + } + + return BSONTimestamp{Seconds: uint32(sv.IntegerValue), Increment: uint32(iv.IntegerValue)}, nil +} + +func bsonDecimal128FromProtoValue(v *pb.Value) (BSONDecimal128, error) { + var d BSONDecimal128 + m, err := assertMapWithValueKey(v, "__decimal128__") + if err != nil { + return d, err + } + s, err := stringFromProtoValue(m["__decimal128__"]) + if err != nil { + return d, err + } + d = BSONDecimal128(s) + return d, nil +} + +func bsonMinKeyFromProtoValue(v *pb.Value) (BSONMinKey, error) { + var m BSONMinKey + _, err := assertMapWithValueKey(v, "__min__") + if err != nil { + return m, err + } + return m, nil +} + +func bsonMaxKeyFromProtoValue(v *pb.Value) (BSONMaxKey, error) { + var m BSONMaxKey + _, err := assertMapWithValueKey(v, "__max__") + if err != nil { + return m, err + } + return m, nil +} + +func bsonBinaryFromProtoValue(v *pb.Value) (BSONBinary, error) { + var b BSONBinary + m, err := assertMapWithValueKey(v, "__binary__") + if err != nil { + return b, err + } + payloadVal := m["__binary__"] + bv, ok := payloadVal.ValueType.(*pb.Value_BytesValue) + if !ok { + return b, fmt.Errorf("firestore: binary value is not bytes: %v", payloadVal.ValueType) + } + payload := bv.BytesValue + if len(payload) == 0 { + return b, fmt.Errorf("firestore: empty binary payload") + } + return BSONBinary{Subtype: payload[0], Data: payload[1:]}, nil +} + +func bsonInt32FromProtoValue(v *pb.Value) (BSONInt32, error) { + m, err := assertMapWithValueKey(v, "__int__") + if err != nil { + return 0, err + } + intVal := m["__int__"] + iv, ok := intVal.ValueType.(*pb.Value_IntegerValue) + if !ok { + return 0, fmt.Errorf("firestore: int32 value is not integer: %v", intVal.ValueType) + } + if iv.IntegerValue < -2147483648 || iv.IntegerValue > 2147483647 { + return 0, fmt.Errorf("firestore: int32 value out of range: %d", iv.IntegerValue) + } + return BSONInt32(iv.IntegerValue), nil +} + +func assertMapWithValueKey(v *pb.Value, key string) (map[string]*pb.Value, error) { + if v == nil { + return nil, fmt.Errorf("firestore: value is nil") + } + pbMap, ok := v.ValueType.(*pb.Value_MapValue) + if !ok { + return nil, fmt.Errorf("firestore: cannot convert %v to *pb.Value_MapValue", v.ValueType) + } + m := pbMap.MapValue.Fields + if _, ok := m[key]; !ok { + return nil, fmt.Errorf("firestore: missing key %q in map %v", key, m) + } + return m, nil +} + +func tryConvertMapToBSONType(m map[string]interface{}) (interface{}, bool, error) { + if len(m) != 1 { + return nil, false, nil + } + for k, v := range m { + switch k { + case "__oid__": + s, ok := v.(string) + if !ok { + return nil, false, fmt.Errorf("firestore: __oid__ value is not string: %T", v) + } + return BSONObjectID(s), true, nil + + case "__regex__": + subMap, ok := v.(map[string]interface{}) + if !ok { + return nil, false, fmt.Errorf("firestore: __regex__ value is not map: %T", v) + } + pattern, ok := subMap["pattern"].(string) + if !ok { + return nil, false, fmt.Errorf("firestore: regex pattern is not string") + } + options, ok := subMap["options"].(string) + if !ok { + return nil, false, fmt.Errorf("firestore: regex options is not string") + } + r := BSONRegex{Pattern: pattern, Options: options} + return r, true, nil + + case "__int__": + i, ok := v.(int64) + if !ok { + return nil, false, fmt.Errorf("firestore: __int__ value is not int64: %T", v) + } + if i < -2147483648 || i > 2147483647 { + return nil, false, fmt.Errorf("firestore: BSON int32 value out of range: %d", i) + } + return BSONInt32(i), true, nil + + case "__request_timestamp__": + subMap, ok := v.(map[string]interface{}) + if !ok { + return nil, false, fmt.Errorf("firestore: __request_timestamp__ value is not map: %T", v) + } + seconds, ok := subMap["seconds"].(int64) + if !ok { + return nil, false, fmt.Errorf("firestore: timestamp seconds is not int64") + } + increment, ok := subMap["increment"].(int64) + if !ok { + return nil, false, fmt.Errorf("firestore: timestamp increment is not int64") + } + if seconds < 0 || seconds > 0xffffffff { + return nil, false, fmt.Errorf("firestore: BSON timestamp seconds out of range: %d", seconds) + } + if increment < 0 || increment > 0xffffffff { + return nil, false, fmt.Errorf("firestore: BSON timestamp increment out of range: %d", increment) + } + return BSONTimestamp{Seconds: uint32(seconds), Increment: uint32(increment)}, true, nil + + case "__decimal128__": + s, ok := v.(string) + if !ok { + return nil, false, fmt.Errorf("firestore: __decimal128__ value is not string: %T", v) + } + return BSONDecimal128(s), true, nil + + case "__min__": + if v != nil { + return nil, false, fmt.Errorf("firestore: __min__ value must be null") + } + return BSONMinKey{}, true, nil + + case "__max__": + if v != nil { + return nil, false, fmt.Errorf("firestore: __max__ value must be null") + } + return BSONMaxKey{}, true, nil + + case "__binary__": + b, ok := v.([]byte) + if !ok { + return nil, false, fmt.Errorf("firestore: __binary__ value is not bytes: %T", v) + } + if len(b) == 0 { + return nil, false, fmt.Errorf("firestore: empty binary payload") + } + return BSONBinary{Subtype: b[0], Data: b[1:]}, true, nil + } + } + return nil, false, nil +} diff --git a/firestore/from_value_test.go b/firestore/from_value_test.go index 4325a3be8f96..d276eae17da9 100644 --- a/firestore/from_value_test.go +++ b/firestore/from_value_test.go @@ -700,3 +700,109 @@ func TestPopulateMap(t *testing.T) { }) } } + +func TestBSONTypes_RoundTrip(t *testing.T) { + oid := BSONObjectID("0123456789abcdef01234567") + + tests := []struct { + desc string + val interface{} + pb *pb.Value + }{ + { + desc: "BSONObjectID", + val: oid, + pb: mapval(map[string]*pb.Value{ + "__oid__": strval("0123456789abcdef01234567"), + }), + }, + { + desc: "BSONRegex", + val: BSONRegex{Pattern: "foo", Options: "im"}, + pb: mapval(map[string]*pb.Value{ + "__regex__": mapval(map[string]*pb.Value{ + "pattern": strval("foo"), + "options": strval("im"), + }), + }), + }, + { + desc: "BSONTimestamp", + val: BSONTimestamp{Seconds: 123, Increment: 456}, + pb: mapval(map[string]*pb.Value{ + "__request_timestamp__": mapval(map[string]*pb.Value{ + "seconds": int64val(123), + "increment": int64val(456), + }), + }), + }, + { + desc: "BSONDecimal128", + val: BSONDecimal128("123.45"), + pb: mapval(map[string]*pb.Value{ + "__decimal128__": strval("123.45"), + }), + }, + { + desc: "BSONMinKey", + val: BSONMinKey{}, + pb: mapval(map[string]*pb.Value{ + "__min__": nullValue, + }), + }, + { + desc: "BSONMaxKey", + val: BSONMaxKey{}, + pb: mapval(map[string]*pb.Value{ + "__max__": nullValue, + }), + }, + { + desc: "BSONBinary", + val: BSONBinary{Subtype: 0x02, Data: []byte{1, 2, 3}}, + pb: mapval(map[string]*pb.Value{ + "__binary__": bytesval([]byte{0x02, 1, 2, 3}), + }), + }, + { + desc: "BSONInt32", + val: BSONInt32(42), + pb: mapval(map[string]*pb.Value{ + "__int__": int64val(42), + }), + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + // Test serialization + gotPB, _, err := toProtoValue(reflect.ValueOf(test.val)) + if err != nil { + t.Fatalf("toProtoValue failed: %v", err) + } + if !testEqual(gotPB, test.pb) { + t.Fatalf("toProtoValue got:\n%v\nwant:\n%v", gotPB, test.pb) + } + + // Test deserialization (specific type) + dest := reflect.New(reflect.TypeOf(test.val)).Interface() + err = setFromProtoValue(dest, gotPB, nil) + if err != nil { + t.Fatalf("setFromProtoValue failed: %v", err) + } + gotVal := reflect.ValueOf(dest).Elem().Interface() + if !testEqual(gotVal, test.val) { + t.Fatalf("setFromProtoValue got:\n%v\nwant:\n%v", gotVal, test.val) + } + + // Test deserialization (generic interface{}) + gotInterface, err := createFromProtoValue(gotPB, nil) + if err != nil { + t.Fatalf("createFromProtoValue failed: %v", err) + } + if !testEqual(gotInterface, test.val) { + t.Fatalf("createFromProtoValue got:\n%v (%T)\nwant:\n%v (%T)", gotInterface, gotInterface, test.val, test.val) + } + }) + } +} diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 2018291d44c8..5809f4714b86 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -3587,6 +3587,78 @@ func TestIntegration_FindNearest(t *testing.T) { } } +func TestIntegration_BSONTypes(t *testing.T) { + skipIfEdition(t, "BSON types", editionStandard) + t.Skip("Temporarily skipping BSON integration test. Not yet released to prod.") + ctx := context.Background() + coll := integrationColl(t) + doc := coll.NewDoc() + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{doc}) + }) + + oid := BSONObjectID("0123456789abcdef01234567") + + data := map[string]interface{}{ + "oid": oid, + "regex": BSONRegex{Pattern: "foo", Options: "im"}, + "timestamp": BSONTimestamp{Seconds: 123, Increment: 456}, + "decimal128": BSONDecimal128("123.45"), + "minkey": BSONMinKey{}, + "maxkey": BSONMaxKey{}, + "binary": BSONBinary{Subtype: 0x02, Data: []byte{1, 2, 3}}, + "bson_int": BSONInt32(42), + } + + _, err := doc.Create(ctx, data) + if err != nil { + t.Fatalf("failed to create doc with BSON types: %v", err) + } + + // If write succeeded, we try to read back and verify. + ds, err := doc.Get(ctx) + if err != nil { + t.Fatalf("failed to get doc: %v", err) + } + + got := ds.Data() + if !testEqual(got, data) { + t.Errorf("got vs want diff:\n%s", testDiff(got, data)) + } + + // Also test decoding into a struct. + type bsonStruct struct { + Oid BSONObjectID `firestore:"oid"` + Regex BSONRegex `firestore:"regex"` + Timestamp BSONTimestamp `firestore:"timestamp"` + Decimal128 BSONDecimal128 `firestore:"decimal128"` + MinKey BSONMinKey `firestore:"minkey"` + MaxKey BSONMaxKey `firestore:"maxkey"` + Binary BSONBinary `firestore:"binary"` + BsonInt BSONInt32 `firestore:"bson_int"` + } + + var gotStruct bsonStruct + if err := ds.DataTo(&gotStruct); err != nil { + t.Fatalf("DataTo failed: %v", err) + } + + wantStruct := bsonStruct{ + Oid: oid, + Regex: BSONRegex{Pattern: "foo", Options: "im"}, + Timestamp: BSONTimestamp{Seconds: 123, Increment: 456}, + Decimal128: BSONDecimal128("123.45"), + MinKey: BSONMinKey{}, + MaxKey: BSONMaxKey{}, + Binary: BSONBinary{Subtype: 0x02, Data: []byte{1, 2, 3}}, + BsonInt: BSONInt32(42), + } + + if !testEqual(gotStruct, wantStruct) { + t.Errorf("got struct vs want struct diff:\n%s", testDiff(gotStruct, wantStruct)) + } +} + func TestIntegration_TransactionReadTime(t *testing.T) { ctx := context.Background() c := integrationClient(t) diff --git a/firestore/order.go b/firestore/order.go index 489936901ef0..601538f8380b 100644 --- a/firestore/order.go +++ b/firestore/order.go @@ -19,12 +19,33 @@ import ( "fmt" "math" "sort" + "strconv" "strings" pb "cloud.google.com/go/firestore/apiv1/firestorepb" tspb "google.golang.org/protobuf/types/known/timestamppb" ) +const ( + typeOrderNull = iota + typeOrderMinKey + typeOrderBoolean + typeOrderNumber + typeOrderTimestamp + typeOrderBSONTimestamp + typeOrderString + typeOrderBlob + typeOrderBSONBinary + typeOrderRef + typeOrderBSONObjectID + typeOrderGeoPoint + typeOrderRegex + typeOrderArray + typeOrderVector + typeOrderObject + typeOrderMaxKey +) + // Returns a negative number, zero, or a positive number depending on whether a is // less than, equal to, or greater than b according to Firestore's ordering of // values. @@ -34,6 +55,15 @@ func compareValues(a, b *pb.Value) int { if ta != tb { return compareInt64s(int64(ta), int64(tb)) } + if ta == typeOrderNumber { + return compareNumbers(extractNumber(a), extractNumber(b)) + } + if ta == typeOrderBSONTimestamp { + return compareBSONTimestamps(a, b) + } + if ta == typeOrderRegex { + return compareBSONRegexes(a, b) + } switch a := a.ValueType.(type) { case *pb.Value_NullValue: return 0 // nulls are equal @@ -189,27 +219,49 @@ func compareInt64s(a, b int64) int { func typeOrder(v *pb.Value) int { switch v.ValueType.(type) { case *pb.Value_NullValue: - return 0 + return typeOrderNull case *pb.Value_BooleanValue: - return 1 - case *pb.Value_IntegerValue: - return 2 - case *pb.Value_DoubleValue: - return 2 + return typeOrderBoolean + case *pb.Value_IntegerValue, *pb.Value_DoubleValue: + return typeOrderNumber case *pb.Value_TimestampValue: - return 3 + return typeOrderTimestamp case *pb.Value_StringValue: - return 4 + return typeOrderString case *pb.Value_BytesValue: - return 5 + return typeOrderBlob case *pb.Value_ReferenceValue: - return 6 + return typeOrderRef case *pb.Value_GeoPointValue: - return 7 + return typeOrderGeoPoint case *pb.Value_ArrayValue: - return 8 + return typeOrderArray case *pb.Value_MapValue: - return 9 + if isBSONMinKey(v) { + return typeOrderMinKey + } + if isBSONMaxKey(v) { + return typeOrderMaxKey + } + if isBSONInt32(v) || isBSONDecimal128(v) { + return typeOrderNumber + } + if isBSONTimestamp(v) { + return typeOrderBSONTimestamp + } + if isBSONBinary(v) { + return typeOrderBSONBinary + } + if isBSONObjectID(v) { + return typeOrderBSONObjectID + } + if isBSONRegex(v) { + return typeOrderRegex + } + if isVector(v) { + return typeOrderVector + } + return typeOrderObject default: panic(fmt.Sprintf("bad value type: %v", v)) } @@ -221,3 +273,175 @@ type byFirestoreValue []*pb.Value func (a byFirestoreValue) Len() int { return len(a) } func (a byFirestoreValue) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a byFirestoreValue) Less(i, j int) bool { return compareValues(a[i], a[j]) < 0 } + +func isMapWithSingleKey(v *pb.Value, key string) (*pb.Value, bool) { + mv, ok := v.ValueType.(*pb.Value_MapValue) + if !ok { + return nil, false + } + fields := mv.MapValue.Fields + if len(fields) != 1 { + return nil, false + } + val, ok := fields[key] + return val, ok +} + +func isBSONMinKey(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__min__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_NullValue) + return ok +} + +func isBSONMaxKey(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__max__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_NullValue) + return ok +} + +func isBSONInt32(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__int__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_IntegerValue) + return ok +} + +func isBSONDecimal128(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__decimal128__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_StringValue) + return ok +} + +func isBSONObjectID(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__oid__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_StringValue) + return ok +} + +func isBSONBinary(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__binary__") + if !ok { + return false + } + _, ok = val.ValueType.(*pb.Value_BytesValue) + return ok +} + +func isBSONTimestamp(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__request_timestamp__") + if !ok { + return false + } + tsMap, ok := val.ValueType.(*pb.Value_MapValue) + if !ok { + return false + } + tf := tsMap.MapValue.Fields + if len(tf) != 2 { + return false + } + _, ok1 := tf["seconds"].ValueType.(*pb.Value_IntegerValue) + _, ok2 := tf["increment"].ValueType.(*pb.Value_IntegerValue) + return ok1 && ok2 +} + +func isBSONRegex(v *pb.Value) bool { + val, ok := isMapWithSingleKey(v, "__regex__") + if !ok { + return false + } + regexMap, ok := val.ValueType.(*pb.Value_MapValue) + if !ok { + return false + } + rf := regexMap.MapValue.Fields + if len(rf) != 2 { + return false + } + _, ok1 := rf["pattern"].ValueType.(*pb.Value_StringValue) + _, ok2 := rf["options"].ValueType.(*pb.Value_StringValue) + return ok1 && ok2 +} + +func isVector(v *pb.Value) bool { + mv, ok := v.ValueType.(*pb.Value_MapValue) + if !ok { + return false + } + fields := mv.MapValue.Fields + typeVal, ok := fields["__type__"] + if !ok { + return false + } + sv, ok := typeVal.ValueType.(*pb.Value_StringValue) + return ok && sv.StringValue == "__vector__" +} + +func extractNumber(v *pb.Value) float64 { + switch x := v.ValueType.(type) { + case *pb.Value_IntegerValue: + return float64(x.IntegerValue) + case *pb.Value_DoubleValue: + return x.DoubleValue + case *pb.Value_MapValue: + if val, ok := isMapWithSingleKey(v, "__int__"); ok { + return float64(val.GetIntegerValue()) + } + if val, ok := isMapWithSingleKey(v, "__decimal128__"); ok { + f, err := strconv.ParseFloat(val.GetStringValue(), 64) + if err != nil { + return math.NaN() + } + return f + } + } + return 0 +} + +func compareBSONTimestamps(a, b *pb.Value) int { + valA, _ := isMapWithSingleKey(a, "__request_timestamp__") + valB, _ := isMapWithSingleKey(b, "__request_timestamp__") + + mapA := valA.GetMapValue().Fields + mapB := valB.GetMapValue().Fields + + secA := mapA["seconds"].GetIntegerValue() + secB := mapB["seconds"].GetIntegerValue() + if c := compareInt64s(secA, secB); c != 0 { + return c + } + incA := mapA["increment"].GetIntegerValue() + incB := mapB["increment"].GetIntegerValue() + return compareInt64s(incA, incB) +} + +func compareBSONRegexes(a, b *pb.Value) int { + valA, _ := isMapWithSingleKey(a, "__regex__") + valB, _ := isMapWithSingleKey(b, "__regex__") + + mapA := valA.GetMapValue().Fields + mapB := valB.GetMapValue().Fields + + patA := mapA["pattern"].GetStringValue() + patB := mapB["pattern"].GetStringValue() + if c := strings.Compare(patA, patB); c != 0 { + return c + } + optA := mapA["options"].GetStringValue() + optB := mapB["options"].GetStringValue() + return strings.Compare(optA, optB) +} diff --git a/firestore/order_test.go b/firestore/order_test.go index 696af19b3daa..3cb5337f2563 100644 --- a/firestore/order_test.go +++ b/firestore/order_test.go @@ -27,16 +27,19 @@ func TestCompareValues(t *testing.T) { // Ordered list of values. vals := []*pb.Value{ nullValue, + bsonMinKey(), // BSON MinKey boolval(false), boolval(true), floatval(math.NaN()), floatval(math.Inf(-1)), floatval(-math.MaxFloat64), int64val(math.MinInt64), + bsonInt32val(-2), // BSON Int32 floatval(-1.1), intval(-1), intval(0), floatval(math.SmallestNonzeroFloat64), + bsonDecimal128val("0.5"), // BSON Decimal128 intval(1), floatval(1.1), intval(2), @@ -45,6 +48,7 @@ func TestCompareValues(t *testing.T) { floatval(math.Inf(1)), tsval(time.Date(2016, 5, 20, 10, 20, 0, 0, time.UTC)), tsval(time.Date(2016, 10, 21, 15, 32, 0, 0, time.UTC)), + bsonTimestampval(1477063920, 0), // BSON Timestamp (2016-10-21 15:32:00 UTC) strval(""), strval("\u0000\ud7ff\ue000\uffff"), strval("(╯°□°)╯︵ ┻━┻"), @@ -58,6 +62,7 @@ func TestCompareValues(t *testing.T) { bytesval([]byte{0, 1, 2, 3, 4}), bytesval([]byte{0, 1, 2, 4, 3}), bytesval([]byte{255}), + bsonBinaryval(2, []byte{1}), // BSON Binary (subtype 2) refval("projects/p1/databases/d1/documents/c1/doc1"), refval("projects/p1/databases/d1/documents/c1/doc2"), refval("projects/p1/databases/d1/documents/c1/doc2/c2/doc1"), @@ -66,6 +71,7 @@ func TestCompareValues(t *testing.T) { refval("projects/p1/databases/dkkkkklkjnjkkk1/documents/c2/doc1"), refval("projects/p2/databases/d2/documents/c1/doc1"), refval("projects/p2/databases/d2/documents/c1-/doc1"), + bsonObjectIDval("0123456789abcdef01234567"), // BSON ObjectID geopoint(-90, -180), geopoint(-90, 0), geopoint(-90, 180), @@ -78,6 +84,7 @@ func TestCompareValues(t *testing.T) { geopoint(90, -180), geopoint(90, 0), geopoint(90, 180), + bsonRegexval("foo", "im"), // BSON Regex arrayval(), arrayval(strval("bar")), arrayval(strval("foo")), @@ -89,6 +96,7 @@ func TestCompareValues(t *testing.T) { mapval(map[string]*pb.Value{"foo": intval(1)}), mapval(map[string]*pb.Value{"foo": intval(2)}), mapval(map[string]*pb.Value{"foo": strval("0")}), + bsonMaxKey(), // BSON MaxKey } for i, v1 := range vals { @@ -100,7 +108,7 @@ func TestCompareValues(t *testing.T) { t.Errorf("compare(%v, %v) == %d, want -1", v1, v2, got) } if got := compareValues(v2, v1); got != 1 { - t.Errorf("compare(%v, %v) == %d, want 1", v1, v2, got) + t.Errorf("compare(%v, %v) == %d, want 1", v2, v1, got) } } } @@ -111,8 +119,85 @@ func TestCompareValues(t *testing.T) { if got := compareValues(n1, n2); got != 0 { t.Errorf("compare(%v, %v) == %d, want 0", n1, n2, got) } + + // BSON Int32 and Decimal128 order same as other numbers. + bn1 := bsonInt32val(17) + bn2 := bsonDecimal128val("17.0") + if got := compareValues(bn1, n1); got != 0 { + t.Errorf("compare(%v, %v) == %d, want 0", bn1, n1, got) + } + if got := compareValues(bn2, n2); got != 0 { + t.Errorf("compare(%v, %v) == %d, want 0", bn2, n2, got) + } + if got := compareValues(bn1, bn2); got != 0 { + t.Errorf("compare(%v, %v) == %d, want 0", bn1, bn2, got) + } + + // Decimal128 NaN orders same as float NaN + nan1 := floatval(math.NaN()) + nan2 := bsonDecimal128val("NaN") + if got := compareValues(nan1, nan2); got != 0 { + t.Errorf("compare(%v, %v) == %d, want 0", nan1, nan2, got) + } } func geopoint(lat, lng float64) *pb.Value { return geoval(&latlng.LatLng{Latitude: lat, Longitude: lng}) } + +func bsonMinKey() *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__min__": nullValue, + }}}} +} + +func bsonMaxKey() *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__max__": nullValue, + }}}} +} + +func bsonInt32val(i int32) *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__int__": int64val(int64(i)), + }}}} +} + +func bsonDecimal128val(s string) *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__decimal128__": strval(s), + }}}} +} + +func bsonTimestampval(seconds, increment int64) *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__request_timestamp__": mapval(map[string]*pb.Value{ + "seconds": int64val(seconds), + "increment": int64val(increment), + }), + }}}} +} + +func bsonBinaryval(subtype byte, data []byte) *pb.Value { + payload := make([]byte, len(data)+1) + payload[0] = subtype + copy(payload[1:], data) + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__binary__": bytesval(payload), + }}}} +} + +func bsonObjectIDval(hexStr string) *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__oid__": strval(hexStr), + }}}} +} + +func bsonRegexval(pattern, options string) *pb.Value { + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "__regex__": mapval(map[string]*pb.Value{ + "pattern": strval(pattern), + "options": strval(options), + }), + }}}} +} diff --git a/firestore/to_value.go b/firestore/to_value.go index 27e4070786d9..acf979b70db8 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -36,6 +36,14 @@ var ( typeOfProtoTimestamp = reflect.TypeOf((*ts.Timestamp)(nil)) typeOfVector64 = reflect.TypeOf(Vector64{}) typeOfVector32 = reflect.TypeOf(Vector32{}) + typeOfBSONObjectID = reflect.TypeOf(BSONObjectID("")) + typeOfBSONRegex = reflect.TypeOf(BSONRegex{}) + typeOfBSONTimestamp = reflect.TypeOf(BSONTimestamp{}) + typeOfBSONDecimal128 = reflect.TypeOf(BSONDecimal128("")) + typeOfBSONMinKey = reflect.TypeOf(BSONMinKey{}) + typeOfBSONMaxKey = reflect.TypeOf(BSONMaxKey{}) + typeOfBSONBinary = reflect.TypeOf(BSONBinary{}) + typeOfBSONInt32 = reflect.TypeOf(BSONInt32(0)) isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() ) @@ -104,6 +112,22 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return nullValue, false, nil } return &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: x.Path}}, false, nil + case BSONObjectID: + return bsonObjectIDToProtoValue(x), false, nil + case BSONRegex: + return bsonRegexToProtoValue(x), false, nil + case BSONTimestamp: + return bsonTimestampToProtoValue(x), false, nil + case BSONDecimal128: + return bsonDecimal128ToProtoValue(x), false, nil + case BSONMinKey: + return bsonMinKeyToProtoValue(), false, nil + case BSONMaxKey: + return bsonMaxKeyToProtoValue(), false, nil + case BSONBinary: + return bsonBinaryToProtoValue(x), false, nil + case BSONInt32: + return bsonInt32ToProtoValue(x), false, nil // Do not add bool, string, int, etc. to this switch; leave them in the // reflect-based switch below. Moving them here would drop support for // types whose underlying types are those primitives. @@ -366,3 +390,120 @@ func isZeroValue(v reflect.Value) bool { } return v.IsZero() } + +func bsonObjectIDToProtoValue(id BSONObjectID) *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__oid__": stringToProtoValue(id.String()), + }, + }, + }, + } +} + +func bsonRegexToProtoValue(r BSONRegex) *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__regex__": { + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "pattern": stringToProtoValue(r.Pattern), + "options": stringToProtoValue(r.Options), + }, + }, + }, + }, + }, + }, + }, + } +} + +func bsonTimestampToProtoValue(t BSONTimestamp) *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__request_timestamp__": { + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "seconds": {ValueType: &pb.Value_IntegerValue{IntegerValue: int64(t.Seconds)}}, + "increment": {ValueType: &pb.Value_IntegerValue{IntegerValue: int64(t.Increment)}}, + }, + }, + }, + }, + }, + }, + }, + } +} + +func bsonDecimal128ToProtoValue(d BSONDecimal128) *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__decimal128__": stringToProtoValue(string(d)), + }, + }, + }, + } +} + +func bsonMinKeyToProtoValue() *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__min__": nullValue, + }, + }, + }, + } +} + +func bsonMaxKeyToProtoValue() *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__max__": nullValue, + }, + }, + }, + } +} + +func bsonBinaryToProtoValue(b BSONBinary) *pb.Value { + payload := make([]byte, len(b.Data)+1) + payload[0] = b.Subtype + copy(payload[1:], b.Data) + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__binary__": {ValueType: &pb.Value_BytesValue{BytesValue: payload}}, + }, + }, + }, + } +} + +func bsonInt32ToProtoValue(i BSONInt32) *pb.Value { + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__int__": {ValueType: &pb.Value_IntegerValue{IntegerValue: int64(i)}}, + }, + }, + }, + } +}