From 93698bce7404a7baea555be1762bbea14ba70475 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Wed, 4 Feb 2026 14:00:33 +0100 Subject: [PATCH 1/2] statedb: Add test for Modify that checks secondary indexes The changes to Modify() to not call merge() when the object did not exist are broken. It only called merge() for the primary index and the other indexes got the new object without merging it with the old one. Add a regression test to catch this issue and extend the quick tests to also check for this. Signed-off-by: Jussi Maki --- db_test.go | 31 ++++++++++++++++++++++++ quick_test.go | 65 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 74 insertions(+), 22 deletions(-) diff --git a/db_test.go b/db_test.go index a6dbc8e..88b06ef 100644 --- a/db_test.go +++ b/db_test.go @@ -755,6 +755,37 @@ func TestDB_Modify(t *testing.T) { require.EqualValues(t, 1, objs[0].ID) } +func TestDB_ModifyUpdatesSecondaryIndex(t *testing.T) { + t.Parallel() + + db, table, _ := newTestDB(t, tagsIndex) + + txn := db.WriteTxn(table) + _, _, err := table.Insert(txn, &testObject{ID: uint64(1), Tags: part.NewSet("foo")}) + require.NoError(t, err, "Insert failed") + + _, hadOld, err := table.Modify(txn, &testObject{ID: uint64(1)}, func(old, new *testObject) *testObject { + return &testObject{ + ID: 1, + Tags: old.Tags.Set("bar"), + } + }) + require.NoError(t, err, "Modify failed") + require.True(t, hadOld, "expected hadOld to be true") + + rtxn := txn.Commit() + + obj, _, found := table.Get(rtxn, tagsIndex.Query("bar")) + require.True(t, found, "expected tags index to include merged tag") + require.EqualValues(t, 1, obj.ID) + require.True(t, obj.Tags.Has("foo")) + require.True(t, obj.Tags.Has("bar")) + + obj, _, found = table.Get(rtxn, tagsIndex.Query("foo")) + require.True(t, found, "expected tags index to retain existing tag") + require.EqualValues(t, 1, obj.ID) +} + func TestDB_Revision(t *testing.T) { t.Parallel() diff --git a/quick_test.go b/quick_test.go index 79d58c2..0b26c85 100644 --- a/quick_test.go +++ b/quick_test.go @@ -120,7 +120,7 @@ func TestDB_Quick(t *testing.T) { numInserted, numRemoved := 0, 0 - check := func(a, b string, remove bool) bool { + check := func(a, b string, remove bool, useModify bool) bool { txn := db.WriteTxn(table) if remove { key := pickRandom() @@ -152,13 +152,31 @@ func TestDB_Quick(t *testing.T) { getObj, _, getWatch, getFound := table.GetWatch(txn, aIndex.Query(a)) - old, hadOld, err := table.Insert(txn, quickObj{a, b}) - require.NoError(t, err, "Insert") + expected, found := values[a] + expectedB := b + + var ( + old quickObj + hadOld bool + err error + ) + if useModify { + old, hadOld, err = table.Modify(txn, quickObj{a, b}, func(old, new quickObj) quickObj { + new.B = old.B + new.B + return new + }) + require.NoError(t, err, "Modify") + if found { + expectedB = expected + b + } + } else { + old, hadOld, err = table.Insert(txn, quickObj{a, b}) + require.NoError(t, err, "Insert") + } numInserted++ require.Equal(t, getFound, hadOld) - expected, found := values[a] if found { if !hadOld { t.Logf("object was updated but old value not returned") @@ -174,7 +192,7 @@ func TestDB_Quick(t *testing.T) { return false } } - values[a] = b + values[a] = expectedB if len(values) != table.NumObjects(txn) { t.Logf("wrong object count") @@ -270,6 +288,7 @@ func TestDB_Quick(t *testing.T) { // // Check against the secondary (non-unique index) // + queryB := expectedB // Non-unique indexes return the same number of objects as we've inserted. if len(values) != seqLen(table.Prefix(rtxn, bIndex.Query(""))) { @@ -284,15 +303,15 @@ func TestDB_Quick(t *testing.T) { // Get returns the first match, but since the index is non-unique, this might // not be the one that we just inserted. - obj, _, found = table.Get(rtxn, bIndex.Query(b)) - if !found || obj.B != b { - t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", b, found, obj.B, b) + obj, _, found = table.Get(rtxn, bIndex.Query(queryB)) + if !found || obj.B != queryB { + t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", queryB, found, obj.B, queryB) return false } found = false - for obj := range table.List(rtxn, bIndex.Query(b)) { - if obj.B != b { + for obj := range table.List(rtxn, bIndex.Query(queryB)) { + if obj.B != queryB { t.Logf("List() via bIndex wrong B") return false } @@ -306,8 +325,8 @@ func TestDB_Quick(t *testing.T) { } visited := map[string]struct{}{} - for obj := range table.Prefix(rtxn, bIndex.Query(b)) { - if !strings.HasPrefix(obj.B, b) { + for obj := range table.Prefix(rtxn, bIndex.Query(queryB)) { + if !strings.HasPrefix(obj.B, queryB) { t.Logf("Prefix() via bIndex has wrong prefix") return false } @@ -318,20 +337,20 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } - anyObjs, err = anyTable.Prefix(rtxn, "b", b) + anyObjs, err = anyTable.Prefix(rtxn, "b", queryB) require.NoError(t, err, "AnyTable.Prefix") for anyObj := range anyObjs { obj := anyObj.(quickObj) - if !strings.HasPrefix(obj.B, b) { - t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, b) + if !strings.HasPrefix(obj.B, queryB) { + t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, queryB) return false } } visited = map[string]struct{}{} - for obj := range table.LowerBound(rtxn, bIndex.Query(b)) { - if cmp.Compare(obj.B, b) < 0 { - t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b)) + for obj := range table.LowerBound(rtxn, bIndex.Query(queryB)) { + if cmp.Compare(obj.B, queryB) < 0 { + t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB)) return false } if _, found := visited[obj.A]; found { @@ -341,12 +360,12 @@ func TestDB_Quick(t *testing.T) { visited[obj.A] = struct{}{} } - anyObjs, err = anyTable.LowerBound(rtxn, "b", b) + anyObjs, err = anyTable.LowerBound(rtxn, "b", queryB) require.NoError(t, err, "AnyTable.LowerBound") for anyObj := range anyObjs { obj := anyObj.(quickObj) - if cmp.Compare(obj.B, b) < 0 { - t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b)) + if cmp.Compare(obj.B, queryB) < 0 { + t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB)) return false } } @@ -371,7 +390,7 @@ func TestDB_Quick(t *testing.T) { // than the default quick value generation to hit the more interesting cases // often. Values: func(args []reflect.Value, rand *rand.Rand) { - if len(args) != 3 { + if len(args) != 4 { panic("unexpected args count") } for i := range args[:2] { @@ -382,6 +401,8 @@ func TestDB_Quick(t *testing.T) { } // Remove 33% of the time args[2] = reflect.ValueOf(rand.Intn(3) == 1) + // Use Modify 50% of the time + args[3] = reflect.ValueOf(rand.Intn(2) == 0) }, })) From c7a61566232a83c51f6633be0d874f21ec484246 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Wed, 4 Feb 2026 14:07:30 +0100 Subject: [PATCH 2/2] statedb: Fix regression in Modify() The change to not call merge() if the object did not exist was broken as merge() wasn't called for the object inserted into secondary indexes. Fix the issue by returning the merged new object from tableIndexTxn.modify and inserting that into the secondary indexes. Signed-off-by: Jussi Maki --- lpm_index.go | 2 +- part/part_test.go | 2 +- part/tree.go | 2 +- part/txn.go | 26 ++++++++++++++------------ part_index.go | 2 +- table.go | 2 +- types.go | 2 +- write_txn.go | 4 +++- 8 files changed, 23 insertions(+), 19 deletions(-) diff --git a/lpm_index.go b/lpm_index.go index 36e3921..2b28a4e 100644 --- a/lpm_index.go +++ b/lpm_index.go @@ -312,7 +312,7 @@ func (l *lpmIndexTxn) insert(key index.Key, obj object) (old object, hadOld bool } // modify implements tableIndexTxn. -func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) { +func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) { panic("LPM index cannot be the primary index") } diff --git a/part/part_test.go b/part/part_test.go index 06f1c32..cd569ae 100644 --- a/part/part_test.go +++ b/part/part_test.go @@ -672,7 +672,7 @@ func Test_modify(t *testing.T) { txn := tree.Txn() for i := range 1000 { - old, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 }) + old, _, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 }) require.True(t, hadOld) require.Equal(t, i+1, old) } diff --git a/part/tree.go b/part/tree.go index 24ea5e6..3b18897 100644 --- a/part/tree.go +++ b/part/tree.go @@ -127,7 +127,7 @@ func (t *Tree[T]) Insert(key []byte, value T) (old T, hadOld bool, tree Tree[T]) // Returns the old value if it exists. func (t *Tree[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, tree Tree[T]) { txn := t.Txn() - old, hadOld = txn.Modify(key, value, mod) + old, _, hadOld = txn.Modify(key, value, mod) tree = txn.CommitAndNotify() return } diff --git a/part/txn.go b/part/txn.go index 4d4936d..d7fa3ed 100644 --- a/part/txn.go +++ b/part/txn.go @@ -80,7 +80,7 @@ func (txn *Txn[T]) Insert(key []byte, value T) (old T, hadOld bool) { // Returns the old value if it exists and a watch channel that closes when the // key changes again. func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch <-chan struct{}) { - old, hadOld, watch, txn.root = txn.insert(txn.root, key, value) + old, _, hadOld, watch, txn.root = txn.insert(txn.root, key, value) validateTree(txn.root, nil, txn.watches, txn.txnID) if !hadOld { txn.size++ @@ -93,19 +93,19 @@ func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch < // Modify a value in the tree. It is up to the // caller to not mutate the value in-place and to return a clone. -// Returns the old value if it exists. -func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool) { - old, hadOld, _ = txn.ModifyWatch(key, value, mod) +// Returns the old value (if it exists) and the new possibly merged value. +func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool) { + old, newValue, hadOld, _ = txn.ModifyWatch(key, value, mod) return } // Modify a value in the tree. If the key does not exist the modify // function is called with the zero value for T. It is up to the // caller to not mutate the value in-place and to return a clone. -// Returns the old value if it exists and a watch channel that closes -// when the key changes again. -func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, watch <-chan struct{}) { - old, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod) +// Returns the old value (if it exists) and the new possibly merged value, +// and a watch channel that closes when the key changes again. +func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool, watch <-chan struct{}) { + old, newValue, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod) validateTree(txn.root, nil, txn.watches, txn.txnID) if !hadOld { txn.size++ @@ -243,17 +243,18 @@ func (txn *Txn[T]) cloneNode(n *header[T]) *header[T] { return n } -func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { +func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, newValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { return txn.modify(root, key, value, nil) } -func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { +func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, newValueOut T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { txn.dirty = true fullKey := key + newValueOut = newValue if root == nil { leaf := newLeaf(txn.opts, key, fullKey, newValue) - return oldValue, false, leaf.watch, leaf.self() + return oldValue, newValueOut, false, leaf.watch, leaf.self() } // Start recursing from the root to find the insertion point. @@ -323,7 +324,8 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T } watch = leaf.watch if mod != nil { - leaf.value = mod(oldValue, newValue) + newValueOut = mod(oldValue, newValue) + leaf.value = newValueOut } else { leaf.value = newValue } diff --git a/part_index.go b/part_index.go index 7bde0da..c227a98 100644 --- a/part_index.go +++ b/part_index.go @@ -321,7 +321,7 @@ func (r *partIndexTxn) len() int { } // modify implements tableIndexTxn. -func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) { +func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) { return r.tx.ModifyWatch(key, obj, mod) } diff --git a/table.go b/table.go index 3ce8a5d..c3b6dcd 100644 --- a/table.go +++ b/table.go @@ -486,7 +486,7 @@ func (t *genTable[Obj]) InsertWatch(txn WriteTxn, obj Obj) (oldObj Obj, hadOld b func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) Obj) (oldObj Obj, hadOld bool, err error) { mergeObjects := func(old object, new object) object { - new.data = merge(old.data.(Obj), obj) + new.data = merge(old.data.(Obj), new.data.(Obj)) return new } var old object diff --git a/types.go b/types.go index e9fd365..8237aae 100644 --- a/types.go +++ b/types.go @@ -414,7 +414,7 @@ type tableIndexTxn interface { tableIndex insert(key index.Key, obj object) (old object, hadOld bool, watch <-chan struct{}) - modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) + modify(key index.Key, obj object, mod func(old, new object) object) (old object, new object, hadOld bool, watch <-chan struct{}) delete(key index.Key) (old object, hadOld bool) reindex(primaryKey index.Key, old object, new object) } diff --git a/write_txn.go b/write_txn.go index 6aae11d..091ac0d 100644 --- a/write_txn.go +++ b/write_txn.go @@ -139,7 +139,9 @@ func (txn *writeTxnState) modify(meta TableMeta, guardRevision Revision, newData if merge == nil { oldObj, oldExists, watch = idIndexTxn.insert(idKey, obj) } else { - oldObj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge) + // Insert the object into the primary index. This returns the merged new + // object which we'll then insert into the secondary indexes. + oldObj, obj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge) } // Sanity check: is the same object being inserted back and thus the