diff --git a/db_test.go b/db_test.go index a6dbc8e..88b06ef 100644 --- a/db_test.go +++ b/db_test.go @@ -755,6 +755,37 @@ func TestDB_Modify(t *testing.T) { require.EqualValues(t, 1, objs[0].ID) } +func TestDB_ModifyUpdatesSecondaryIndex(t *testing.T) { + t.Parallel() + + db, table, _ := newTestDB(t, tagsIndex) + + txn := db.WriteTxn(table) + _, _, err := table.Insert(txn, &testObject{ID: uint64(1), Tags: part.NewSet("foo")}) + require.NoError(t, err, "Insert failed") + + _, hadOld, err := table.Modify(txn, &testObject{ID: uint64(1)}, func(old, new *testObject) *testObject { + return &testObject{ + ID: 1, + Tags: old.Tags.Set("bar"), + } + }) + require.NoError(t, err, "Modify failed") + require.True(t, hadOld, "expected hadOld to be true") + + rtxn := txn.Commit() + + obj, _, found := table.Get(rtxn, tagsIndex.Query("bar")) + require.True(t, found, "expected tags index to include merged tag") + require.EqualValues(t, 1, obj.ID) + require.True(t, obj.Tags.Has("foo")) + require.True(t, obj.Tags.Has("bar")) + + obj, _, found = table.Get(rtxn, tagsIndex.Query("foo")) + require.True(t, found, "expected tags index to retain existing tag") + require.EqualValues(t, 1, obj.ID) +} + func TestDB_Revision(t *testing.T) { t.Parallel() diff --git a/lpm_index.go b/lpm_index.go index 36e3921..2b28a4e 100644 --- a/lpm_index.go +++ b/lpm_index.go @@ -312,7 +312,7 @@ func (l *lpmIndexTxn) insert(key index.Key, obj object) (old object, hadOld bool } // modify implements tableIndexTxn. -func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) { +func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) { panic("LPM index cannot be the primary index") } diff --git a/part/part_test.go b/part/part_test.go index 06f1c32..cd569ae 100644 --- a/part/part_test.go +++ b/part/part_test.go @@ -672,7 +672,7 @@ func Test_modify(t *testing.T) { txn := tree.Txn() for i := range 1000 { - old, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 }) + old, _, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 }) require.True(t, hadOld) require.Equal(t, i+1, old) } diff --git a/part/tree.go b/part/tree.go index 24ea5e6..3b18897 100644 --- a/part/tree.go +++ b/part/tree.go @@ -127,7 +127,7 @@ func (t *Tree[T]) Insert(key []byte, value T) (old T, hadOld bool, tree Tree[T]) // Returns the old value if it exists. func (t *Tree[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, tree Tree[T]) { txn := t.Txn() - old, hadOld = txn.Modify(key, value, mod) + old, _, hadOld = txn.Modify(key, value, mod) tree = txn.CommitAndNotify() return } diff --git a/part/txn.go b/part/txn.go index 4d4936d..d7fa3ed 100644 --- a/part/txn.go +++ b/part/txn.go @@ -80,7 +80,7 @@ func (txn *Txn[T]) Insert(key []byte, value T) (old T, hadOld bool) { // Returns the old value if it exists and a watch channel that closes when the // key changes again. func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch <-chan struct{}) { - old, hadOld, watch, txn.root = txn.insert(txn.root, key, value) + old, _, hadOld, watch, txn.root = txn.insert(txn.root, key, value) validateTree(txn.root, nil, txn.watches, txn.txnID) if !hadOld { txn.size++ @@ -93,19 +93,19 @@ func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch < // Modify a value in the tree. It is up to the // caller to not mutate the value in-place and to return a clone. -// Returns the old value if it exists. -func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool) { - old, hadOld, _ = txn.ModifyWatch(key, value, mod) +// Returns the old value (if it exists) and the new possibly merged value. +func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool) { + old, newValue, hadOld, _ = txn.ModifyWatch(key, value, mod) return } // Modify a value in the tree. If the key does not exist the modify // function is called with the zero value for T. It is up to the // caller to not mutate the value in-place and to return a clone. -// Returns the old value if it exists and a watch channel that closes -// when the key changes again. -func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, watch <-chan struct{}) { - old, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod) +// Returns the old value (if it exists) and the new possibly merged value, +// and a watch channel that closes when the key changes again. +func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool, watch <-chan struct{}) { + old, newValue, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod) validateTree(txn.root, nil, txn.watches, txn.txnID) if !hadOld { txn.size++ @@ -243,17 +243,18 @@ func (txn *Txn[T]) cloneNode(n *header[T]) *header[T] { return n } -func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { +func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, newValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { return txn.modify(root, key, value, nil) } -func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { +func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, newValueOut T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { txn.dirty = true fullKey := key + newValueOut = newValue if root == nil { leaf := newLeaf(txn.opts, key, fullKey, newValue) - return oldValue, false, leaf.watch, leaf.self() + return oldValue, newValueOut, false, leaf.watch, leaf.self() } // Start recursing from the root to find the insertion point. @@ -323,7 +324,8 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T } watch = leaf.watch if mod != nil { - leaf.value = mod(oldValue, newValue) + newValueOut = mod(oldValue, newValue) + leaf.value = newValueOut } else { leaf.value = newValue } diff --git a/part_index.go b/part_index.go index 7bde0da..c227a98 100644 --- a/part_index.go +++ b/part_index.go @@ -321,7 +321,7 @@ func (r *partIndexTxn) len() int { } // modify implements tableIndexTxn. -func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) { +func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) { return r.tx.ModifyWatch(key, obj, mod) } diff --git a/quick_test.go b/quick_test.go index 79d58c2..0b26c85 100644 --- a/quick_test.go +++ b/quick_test.go @@ -120,7 +120,7 @@ func TestDB_Quick(t *testing.T) { numInserted, numRemoved := 0, 0 - check := func(a, b string, remove bool) bool { + check := func(a, b string, remove bool, useModify bool) bool { txn := db.WriteTxn(table) if remove { key := pickRandom() @@ -152,13 +152,31 @@ func TestDB_Quick(t *testing.T) { getObj, _, getWatch, getFound := table.GetWatch(txn, aIndex.Query(a)) - old, hadOld, err := table.Insert(txn, quickObj{a, b}) - require.NoError(t, err, "Insert") + expected, found := values[a] + expectedB := b + + var ( + old quickObj + hadOld bool + err error + ) + if useModify { + old, hadOld, err = table.Modify(txn, quickObj{a, b}, func(old, new quickObj) quickObj { + new.B = old.B + new.B + return new + }) + require.NoError(t, err, "Modify") + if found { + expectedB = expected + b + } + } else { + old, hadOld, err = table.Insert(txn, quickObj{a, b}) + require.NoError(t, err, "Insert") + } numInserted++ require.Equal(t, getFound, hadOld) - expected, found := values[a] if found { if !hadOld { t.Logf("object was updated but old value not returned") @@ -174,7 +192,7 @@ func TestDB_Quick(t *testing.T) { return false } } - values[a] = b + values[a] = expectedB if len(values) != table.NumObjects(txn) { t.Logf("wrong object count") @@ -270,6 +288,7 @@ func TestDB_Quick(t *testing.T) { // // Check against the secondary (non-unique index) // + queryB := expectedB // Non-unique indexes return the same number of objects as we've inserted. if len(values) != seqLen(table.Prefix(rtxn, bIndex.Query(""))) { @@ -284,15 +303,15 @@ func TestDB_Quick(t *testing.T) { // Get returns the first match, but since the index is non-unique, this might // not be the one that we just inserted. - obj, _, found = table.Get(rtxn, bIndex.Query(b)) - if !found || obj.B != b { - t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", b, found, obj.B, b) + obj, _, found = table.Get(rtxn, bIndex.Query(queryB)) + if !found || obj.B != queryB { + t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", queryB, found, obj.B, queryB) return false } found = false - for obj := range table.List(rtxn, bIndex.Query(b)) { - if obj.B != b { + for obj := range table.List(rtxn, bIndex.Query(queryB)) { + if obj.B != queryB { t.Logf("List() via bIndex wrong B") return false } @@ -306,8 +325,8 @@ func TestDB_Quick(t *testing.T) { } visited := map[string]struct{}{} - for obj := range table.Prefix(rtxn, bIndex.Query(b)) { - if !strings.HasPrefix(obj.B, b) { + for obj := range table.Prefix(rtxn, bIndex.Query(queryB)) { + if !strings.HasPrefix(obj.B, queryB) { t.Logf("Prefix() via bIndex has wrong prefix") return false } @@ -318,20 +337,20 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } - anyObjs, err = anyTable.Prefix(rtxn, "b", b) + anyObjs, err = anyTable.Prefix(rtxn, "b", queryB) require.NoError(t, err, "AnyTable.Prefix") for anyObj := range anyObjs { obj := anyObj.(quickObj) - if !strings.HasPrefix(obj.B, b) { - t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, b) + if !strings.HasPrefix(obj.B, queryB) { + t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, queryB) return false } } visited = map[string]struct{}{} - for obj := range table.LowerBound(rtxn, bIndex.Query(b)) { - if cmp.Compare(obj.B, b) < 0 { - t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b)) + for obj := range table.LowerBound(rtxn, bIndex.Query(queryB)) { + if cmp.Compare(obj.B, queryB) < 0 { + t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB)) return false } if _, found := visited[obj.A]; found { @@ -341,12 +360,12 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } - anyObjs, err = anyTable.LowerBound(rtxn, "b", b) + anyObjs, err = anyTable.LowerBound(rtxn, "b", queryB) require.NoError(t, err, "AnyTable.LowerBound") for anyObj := range anyObjs { obj := anyObj.(quickObj) - if cmp.Compare(obj.B, b) < 0 { - t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b)) + if cmp.Compare(obj.B, queryB) < 0 { + t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB)) return false } } @@ -371,7 +390,7 @@ func TestDB_Quick(t *testing.T) { // than the default quick value generation to hit the more interesting cases // often. Values: func(args []reflect.Value, rand *rand.Rand) { - if len(args) != 3 { + if len(args) != 4 { panic("unexpected args count") } for i := range args[:2] { @@ -382,6 +401,8 @@ func TestDB_Quick(t *testing.T) { } // Remove 33% of the time args[2] = reflect.ValueOf(rand.Intn(3) == 1) + // Use Modify 50% of the time + args[3] = reflect.ValueOf(rand.Intn(2) == 0) }, })) diff --git a/table.go b/table.go index 3ce8a5d..c3b6dcd 100644 --- a/table.go +++ b/table.go @@ -486,7 +486,7 @@ func (t *genTable[Obj]) InsertWatch(txn WriteTxn, obj Obj) (oldObj Obj, hadOld b func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) Obj) (oldObj Obj, hadOld bool, err error) { mergeObjects := func(old object, new object) object { - new.data = merge(old.data.(Obj), obj) + new.data = merge(old.data.(Obj), new.data.(Obj)) return new } var old object diff --git a/types.go b/types.go index e9fd365..8237aae 100644 --- a/types.go +++ b/types.go @@ -414,7 +414,7 @@ type tableIndexTxn interface { tableIndex insert(key index.Key, obj object) (old object, hadOld bool, watch <-chan struct{}) - modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) + modify(key index.Key, obj object, mod func(old, new object) object) (old object, new object, hadOld bool, watch <-chan struct{}) delete(key index.Key) (old object, hadOld bool) reindex(primaryKey index.Key, old object, new object) } diff --git a/write_txn.go b/write_txn.go index 6aae11d..091ac0d 100644 --- a/write_txn.go +++ b/write_txn.go @@ -139,7 +139,9 @@ func (txn *writeTxnState) modify(meta TableMeta, guardRevision Revision, newData if merge == nil { oldObj, oldExists, watch = idIndexTxn.insert(idKey, obj) } else { - oldObj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge) + // Insert the object into the primary index. This returns the merged new + // object which we'll then insert into the secondary indexes. + oldObj, obj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge) } // Sanity check: is the same object being inserted back and thus the