Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,37 @@ func TestDB_Modify(t *testing.T) {
require.EqualValues(t, 1, objs[0].ID)
}

func TestDB_ModifyUpdatesSecondaryIndex(t *testing.T) {
t.Parallel()

db, table, _ := newTestDB(t, tagsIndex)

txn := db.WriteTxn(table)
_, _, err := table.Insert(txn, &testObject{ID: uint64(1), Tags: part.NewSet("foo")})
require.NoError(t, err, "Insert failed")

_, hadOld, err := table.Modify(txn, &testObject{ID: uint64(1)}, func(old, new *testObject) *testObject {
return &testObject{
ID: 1,
Tags: old.Tags.Set("bar"),
}
})
require.NoError(t, err, "Modify failed")
require.True(t, hadOld, "expected hadOld to be true")

rtxn := txn.Commit()

obj, _, found := table.Get(rtxn, tagsIndex.Query("bar"))
require.True(t, found, "expected tags index to include merged tag")
require.EqualValues(t, 1, obj.ID)
require.True(t, obj.Tags.Has("foo"))
require.True(t, obj.Tags.Has("bar"))

obj, _, found = table.Get(rtxn, tagsIndex.Query("foo"))
require.True(t, found, "expected tags index to retain existing tag")
require.EqualValues(t, 1, obj.ID)
}

func TestDB_Revision(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion lpm_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func (l *lpmIndexTxn) insert(key index.Key, obj object) (old object, hadOld bool
}

// modify implements tableIndexTxn.
func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) {
func (l *lpmIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) {
panic("LPM index cannot be the primary index")
}

Expand Down
2 changes: 1 addition & 1 deletion part/part_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ func Test_modify(t *testing.T) {

txn := tree.Txn()
for i := range 1000 {
old, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 })
old, _, hadOld := txn.Modify(key, 123, func(x, _ int) int { return x + 1 })
require.True(t, hadOld)
require.Equal(t, i+1, old)
}
Expand Down
2 changes: 1 addition & 1 deletion part/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (t *Tree[T]) Insert(key []byte, value T) (old T, hadOld bool, tree Tree[T])
// Returns the old value if it exists.
func (t *Tree[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, tree Tree[T]) {
txn := t.Txn()
old, hadOld = txn.Modify(key, value, mod)
old, _, hadOld = txn.Modify(key, value, mod)
tree = txn.CommitAndNotify()
return
}
Expand Down
26 changes: 14 additions & 12 deletions part/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (txn *Txn[T]) Insert(key []byte, value T) (old T, hadOld bool) {
// Returns the old value if it exists and a watch channel that closes when the
// key changes again.
func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch <-chan struct{}) {
old, hadOld, watch, txn.root = txn.insert(txn.root, key, value)
old, _, hadOld, watch, txn.root = txn.insert(txn.root, key, value)
validateTree(txn.root, nil, txn.watches, txn.txnID)
if !hadOld {
txn.size++
Expand All @@ -93,19 +93,19 @@ func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch <

// Modify a value in the tree. It is up to the
// caller to not mutate the value in-place and to return a clone.
// Returns the old value if it exists.
func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, hadOld bool) {
old, hadOld, _ = txn.ModifyWatch(key, value, mod)
// Returns the old value (if it exists) and the new possibly merged value.
func (txn *Txn[T]) Modify(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool) {
old, newValue, hadOld, _ = txn.ModifyWatch(key, value, mod)
return
}

// Modify a value in the tree. If the key does not exist the modify
// function is called with the zero value for T. It is up to the
// caller to not mutate the value in-place and to return a clone.
// Returns the old value if it exists and a watch channel that closes
// when the key changes again.
func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, hadOld bool, watch <-chan struct{}) {
old, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod)
// Returns the old value (if it exists) and the new possibly merged value,
// and a watch channel that closes when the key changes again.
func (txn *Txn[T]) ModifyWatch(key []byte, value T, mod func(T, T) T) (old T, newValue T, hadOld bool, watch <-chan struct{}) {
old, newValue, hadOld, watch, txn.root = txn.modify(txn.root, key, value, mod)
validateTree(txn.root, nil, txn.watches, txn.txnID)
if !hadOld {
txn.size++
Expand Down Expand Up @@ -243,17 +243,18 @@ func (txn *Txn[T]) cloneNode(n *header[T]) *header[T] {
return n
}

func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) {
func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, newValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) {
return txn.modify(root, key, value, nil)
}

func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) {
func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T) T) (oldValue T, newValueOut T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) {
txn.dirty = true
fullKey := key
newValueOut = newValue

if root == nil {
leaf := newLeaf(txn.opts, key, fullKey, newValue)
return oldValue, false, leaf.watch, leaf.self()
return oldValue, newValueOut, false, leaf.watch, leaf.self()
}

// Start recursing from the root to find the insertion point.
Expand Down Expand Up @@ -323,7 +324,8 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, newValue T, mod func(T, T
}
watch = leaf.watch
if mod != nil {
leaf.value = mod(oldValue, newValue)
newValueOut = mod(oldValue, newValue)
leaf.value = newValueOut
} else {
leaf.value = newValue
}
Expand Down
2 changes: 1 addition & 1 deletion part_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (r *partIndexTxn) len() int {
}

// modify implements tableIndexTxn.
func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{}) {
func (r *partIndexTxn) modify(key index.Key, obj object, mod func(old, new object) object) (old object, newObj object, hadOld bool, watch <-chan struct{}) {
return r.tx.ModifyWatch(key, obj, mod)
}

Expand Down
65 changes: 43 additions & 22 deletions quick_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestDB_Quick(t *testing.T) {

numInserted, numRemoved := 0, 0

check := func(a, b string, remove bool) bool {
check := func(a, b string, remove bool, useModify bool) bool {
txn := db.WriteTxn(table)
if remove {
key := pickRandom()
Expand Down Expand Up @@ -152,13 +152,31 @@ func TestDB_Quick(t *testing.T) {

getObj, _, getWatch, getFound := table.GetWatch(txn, aIndex.Query(a))

old, hadOld, err := table.Insert(txn, quickObj{a, b})
require.NoError(t, err, "Insert")
expected, found := values[a]
expectedB := b

var (
old quickObj
hadOld bool
err error
)
if useModify {
old, hadOld, err = table.Modify(txn, quickObj{a, b}, func(old, new quickObj) quickObj {
new.B = old.B + new.B
return new
})
require.NoError(t, err, "Modify")
if found {
expectedB = expected + b
}
} else {
old, hadOld, err = table.Insert(txn, quickObj{a, b})
require.NoError(t, err, "Insert")
}
numInserted++

require.Equal(t, getFound, hadOld)

expected, found := values[a]
if found {
if !hadOld {
t.Logf("object was updated but old value not returned")
Expand All @@ -174,7 +192,7 @@ func TestDB_Quick(t *testing.T) {
return false
}
}
values[a] = b
values[a] = expectedB

if len(values) != table.NumObjects(txn) {
t.Logf("wrong object count")
Expand Down Expand Up @@ -270,6 +288,7 @@ func TestDB_Quick(t *testing.T) {
//
// Check against the secondary (non-unique index)
//
queryB := expectedB

// Non-unique indexes return the same number of objects as we've inserted.
if len(values) != seqLen(table.Prefix(rtxn, bIndex.Query(""))) {
Expand All @@ -284,15 +303,15 @@ func TestDB_Quick(t *testing.T) {

// Get returns the first match, but since the index is non-unique, this might
// not be the one that we just inserted.
obj, _, found = table.Get(rtxn, bIndex.Query(b))
if !found || obj.B != b {
t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", b, found, obj.B, b)
obj, _, found = table.Get(rtxn, bIndex.Query(queryB))
if !found || obj.B != queryB {
t.Logf("Get(%q) via bIndex not found (%v) or wrong B (%q vs %q)", queryB, found, obj.B, queryB)
return false
}

found = false
for obj := range table.List(rtxn, bIndex.Query(b)) {
if obj.B != b {
for obj := range table.List(rtxn, bIndex.Query(queryB)) {
if obj.B != queryB {
t.Logf("List() via bIndex wrong B")
return false
}
Expand All @@ -306,8 +325,8 @@ func TestDB_Quick(t *testing.T) {
}

visited := map[string]struct{}{}
for obj := range table.Prefix(rtxn, bIndex.Query(b)) {
if !strings.HasPrefix(obj.B, b) {
for obj := range table.Prefix(rtxn, bIndex.Query(queryB)) {
if !strings.HasPrefix(obj.B, queryB) {
t.Logf("Prefix() via bIndex has wrong prefix")
return false
}
Expand All @@ -318,20 +337,20 @@ func TestDB_Quick(t *testing.T) {
visited[obj.A] = struct{}{}
}

anyObjs, err = anyTable.Prefix(rtxn, "b", b)
anyObjs, err = anyTable.Prefix(rtxn, "b", queryB)
require.NoError(t, err, "AnyTable.Prefix")
for anyObj := range anyObjs {
obj := anyObj.(quickObj)
if !strings.HasPrefix(obj.B, b) {
t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, b)
if !strings.HasPrefix(obj.B, queryB) {
t.Logf("AnyTable.Prefix() via bIndex has wrong prefix: %q vs %q", obj.B, queryB)
return false
}
}

visited = map[string]struct{}{}
for obj := range table.LowerBound(rtxn, bIndex.Query(b)) {
if cmp.Compare(obj.B, b) < 0 {
t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b))
for obj := range table.LowerBound(rtxn, bIndex.Query(queryB)) {
if cmp.Compare(obj.B, queryB) < 0 {
t.Logf("LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB))
return false
}
if _, found := visited[obj.A]; found {
Expand All @@ -341,12 +360,12 @@ func TestDB_Quick(t *testing.T) {
visited[obj.A] = struct{}{}
}

anyObjs, err = anyTable.LowerBound(rtxn, "b", b)
anyObjs, err = anyTable.LowerBound(rtxn, "b", queryB)
require.NoError(t, err, "AnyTable.LowerBound")
for anyObj := range anyObjs {
obj := anyObj.(quickObj)
if cmp.Compare(obj.B, b) < 0 {
t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(b))
if cmp.Compare(obj.B, queryB) < 0 {
t.Logf("AnyTable.LowerBound() via bIndex has wrong objects, expected %v >= %v", []byte(obj.B), []byte(queryB))
return false
}
}
Expand All @@ -371,7 +390,7 @@ func TestDB_Quick(t *testing.T) {
// than the default quick value generation to hit the more interesting cases
// often.
Values: func(args []reflect.Value, rand *rand.Rand) {
if len(args) != 3 {
if len(args) != 4 {
panic("unexpected args count")
}
for i := range args[:2] {
Expand All @@ -382,6 +401,8 @@ func TestDB_Quick(t *testing.T) {
}
// Remove 33% of the time
args[2] = reflect.ValueOf(rand.Intn(3) == 1)
// Use Modify 50% of the time
args[3] = reflect.ValueOf(rand.Intn(2) == 0)
},
}))

Expand Down
2 changes: 1 addition & 1 deletion table.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ func (t *genTable[Obj]) InsertWatch(txn WriteTxn, obj Obj) (oldObj Obj, hadOld b

func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) Obj) (oldObj Obj, hadOld bool, err error) {
mergeObjects := func(old object, new object) object {
new.data = merge(old.data.(Obj), obj)
new.data = merge(old.data.(Obj), new.data.(Obj))
return new
}
var old object
Expand Down
2 changes: 1 addition & 1 deletion types.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ type tableIndexTxn interface {
tableIndex

insert(key index.Key, obj object) (old object, hadOld bool, watch <-chan struct{})
modify(key index.Key, obj object, mod func(old, new object) object) (old object, hadOld bool, watch <-chan struct{})
modify(key index.Key, obj object, mod func(old, new object) object) (old object, new object, hadOld bool, watch <-chan struct{})
delete(key index.Key) (old object, hadOld bool)
reindex(primaryKey index.Key, old object, new object)
}
Expand Down
4 changes: 3 additions & 1 deletion write_txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ func (txn *writeTxnState) modify(meta TableMeta, guardRevision Revision, newData
if merge == nil {
oldObj, oldExists, watch = idIndexTxn.insert(idKey, obj)
} else {
oldObj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge)
// Insert the object into the primary index. This returns the merged new
// object which we'll then insert into the secondary indexes.
oldObj, obj, oldExists, watch = idIndexTxn.modify(idKey, obj, merge)
}

// Sanity check: is the same object being inserted back and thus the
Expand Down