From 65ecbbfe1c8034a15b4b38a9af3040ca5f17bdfc Mon Sep 17 00:00:00 2001 From: spekary Date: Fri, 8 Nov 2024 16:54:03 -0800 Subject: [PATCH 1/2] Adding Go 1.23 iterators and other features --- .gitignore | 3 + equaler.go | 35 ++++++ go.mod | 4 +- go.sum | 2 + map.go | 104 +++++++++++++--- map_test.go | 15 +++ mapi.go | 27 ++++ mapi_test.go | 144 +++++++++++++++++++++ safe_map.go | 98 ++++++++++++++- safe_map_test.go | 9 ++ safe_slice_map.go | 275 ++++++++++++++++++----------------------- safe_slice_map_test.go | 47 +++++++ set.go | 63 +++++++++- set_test.go | 15 +++ seti.go | 6 + seti_test.go | 38 ++++++ slice_map.go | 136 ++++++++++++++------ slice_map_test.go | 49 +++++++- std_map.go | 69 +++++++++-- std_map_test.go | 160 ++++++++++++++++++++++++ 20 files changed, 1070 insertions(+), 229 deletions(-) create mode 100644 equaler.go diff --git a/.gitignore b/.gitignore index 66fd13c..b1806d7 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +# OS specific files +.DS_Store \ No newline at end of file diff --git a/equaler.go b/equaler.go new file mode 100644 index 0000000..af1f6d9 --- /dev/null +++ b/equaler.go @@ -0,0 +1,35 @@ +package maps + +// Equaler is the interface that implements an Equal function and that provides a way for the +// various MapI like objects to determine if they are equal. +// +// In particular, if your Map has +// non-comparible values, like a slice, but you would still like to call Equal() on that +// map, define an Equal function on the values to do the comparison. For example: +// +// type mySlice []int +// +// func (s mySlice) Equal(b any) bool { +// if s2, ok := b.(mySlice); ok { +// if len(s) == len(s2) { +// for i, v := range s2 { +// if s[i] != v { +// return false +// } +// } +// return true +// } +// } +// return false +// } +type Equaler interface { + Equal(a any) bool +} + +func equalValues(a, b any) bool { + if e, ok := a.(Equaler); ok { + return e.Equal(b) + } + + return a == b +} diff --git a/go.mod b/go.mod index e8f9b46..243a1f6 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/goradd/maps -go 1.18 +go 1.23 -require github.com/stretchr/testify v1.8.4 +require github.com/stretchr/testify v1.9.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index fa4b6e6..93786ae 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/map.go b/map.go index 5cf2e4a..e186f29 100644 --- a/map.go +++ b/map.go @@ -1,25 +1,40 @@ package maps +import ( + "iter" +) + // Map is a go map that uses a standard set of functions shared with other Map-like types. // // The recommended way to create a Map is to first declare a concrete type alias, and then call // new on it, like this: -// type MyMap = Map[string,int] // -// m := new(MyMap) +// type MyMap = Map[string,int] +// +// m := new(MyMap) // // This will allow you to swap in a different kind of Map just by changing the type. type Map[K comparable, V any] struct { items StdMap[K, V] } +// NewMap creates a new map that maps values of type K to values of type V. +// Pass in zero or more standard maps and the contents of those maps will be copied to the new Map. +func NewMap[K comparable, V any](sources ...map[K]V) *Map[K, V] { + m := new(Map[K, V]) + for _, i := range sources { + m.Copy(Cast(i)) + } + return m +} + // Clear resets the map to an empty map func (m *Map[K, V]) Clear() { m.items = nil } // Len returns the number of items in the map -func (m Map[K, V]) Len() int { +func (m *Map[K, V]) Len() int { return m.items.Len() } @@ -27,38 +42,38 @@ func (m Map[K, V]) Len() int { // This is the same interface as sync.Map.Range(). // While its safe to call methods of the map from within the Range function, its discouraged. // If you ever switch to one of the SafeMap maps, it will cause a deadlock. -func (m Map[K, V]) Range(f func(k K, v V) bool) { +func (m *Map[K, V]) Range(f func(k K, v V) bool) { m.items.Range(f) } // Load returns the value based on its key, and a boolean indicating whether it exists in the map. // This is the same interface as sync.Map.Load() -func (m Map[K, V]) Load(k K) (V, bool) { +func (m *Map[K, V]) Load(k K) (V, bool) { return m.items.Load(k) } // Get returns the value for the given key. If the key does not exist, the zero value will be returned. -func (m Map[K, V]) Get(k K) V { +func (m *Map[K, V]) Get(k K) V { return m.items.Get(k) } // Has returns true if the key exists. -func (m Map[K, V]) Has(k K) bool { +func (m *Map[K, V]) Has(k K) bool { return m.items.Has(k) } // Delete removes the key from the map. If the key does not exist, nothing happens. -func (m Map[K, V]) Delete(k K) { +func (m *Map[K, V]) Delete(k K) { m.items.Delete(k) } // Keys returns a new slice containing the keys of the map. -func (m Map[K, V]) Keys() []K { +func (m *Map[K, V]) Keys() []K { return m.items.Keys() } // Values returns a new slice containing the values of the map. -func (m Map[K, V]) Values() []V { +func (m *Map[K, V]) Values() []V { return m.items.Values() } @@ -72,38 +87,45 @@ func (m *Map[K, V]) Set(k K, v V) { } // Merge copies the items from in to the map, overwriting any conflicting keys. +// Deprecated: Call Copy instead. func (m *Map[K, V]) Merge(in MapI[K, V]) { + m.Copy(in) +} + +// Copy copies the items from in to the map, overwriting any conflicting keys. +func (m *Map[K, V]) Copy(in MapI[K, V]) { if m.items == nil { m.items = make(map[K]V, in.Len()) } - m.items.Merge(in) + m.items.Copy(in) } // Equal returns true if all the keys and values are equal. // // If the values are not comparable, you should implement the Equaler interface on the values. // Otherwise, you will get a runtime panic. -func (m Map[K, V]) Equal(m2 MapI[K, V]) bool { +func (m *Map[K, V]) Equal(m2 MapI[K, V]) bool { return m.items.Equal(m2) } // MarshalBinary implements the BinaryMarshaler interface to convert the map to a byte stream. -func (m Map[K, V]) MarshalBinary() ([]byte, error) { +func (m *Map[K, V]) MarshalBinary() ([]byte, error) { return m.items.MarshalBinary() } // UnmarshalBinary implements the BinaryUnmarshaler interface to convert a byte stream to a Map. // // Note that you may need to register the map at init time with gob like this: -// func init() { -// gob.Register(new(Map[keytype,valuetype])) -// } +// +// func init() { +// gob.Register(new(Map[keytype,valuetype])) +// } func (m *Map[K, V]) UnmarshalBinary(data []byte) (err error) { return m.items.UnmarshalBinary(data) } // MarshalJSON implements the json.Marshaler interface to convert the map into a JSON object. -func (m Map[K, V]) MarshalJSON() (out []byte, err error) { +func (m *Map[K, V]) MarshalJSON() (out []byte, err error) { return m.items.MarshalJSON() } @@ -114,6 +136,52 @@ func (m *Map[K, V]) UnmarshalJSON(in []byte) (err error) { } // String returns the map as a string. -func (m Map[K, V]) String() string { +func (m *Map[K, V]) String() string { return m.items.String() } + +// All returns an iterator over all the items in the map. +func (m *Map[K, V]) All() iter.Seq2[K, V] { + return m.items.All() +} + +// KeysIter returns an iterator over all the keys in the map. +func (m *Map[K, V]) KeysIter() iter.Seq[K] { + return m.items.KeysIter() +} + +// ValuesIter returns an iterator over all the values in the map. +func (m *Map[K, V]) ValuesIter() iter.Seq[V] { + return m.items.ValuesIter() +} + +// Insert adds the values from seq to the map. +// Duplicate keys are overridden. +func (m *Map[K, V]) Insert(seq iter.Seq2[K, V]) { + if m.items == nil { + m.items = map[K]V{} + } + + m.items.Insert(seq) +} + +// CollectMap collects key-value pairs from seq into a new Map +// and returns it. +func CollectMap[K comparable, V any](seq iter.Seq2[K, V]) *Map[K, V] { + m := new(Map[K, V]) + m.Insert(seq) + return m +} + +// Clone returns a copy of the Map. This is a shallow clone: +// the new keys and values are set using ordinary assignment. +func (m *Map[K, V]) Clone() *Map[K, V] { + m1 := new(Map[K, V]) + m1.items = m.items.Clone() + return m1 +} + +// DeleteFunc deletes any key/value pairs for which del returns true. +func (m *Map[K, V]) DeleteFunc(del func(K, V) bool) { + m.items.DeleteFunc(del) +} diff --git a/map_test.go b/map_test.go index d80fc78..1b78fb6 100644 --- a/map_test.go +++ b/map_test.go @@ -3,6 +3,7 @@ package maps import ( "encoding/gob" "fmt" + "github.com/stretchr/testify/assert" "testing" ) @@ -21,3 +22,17 @@ func ExampleMap_String() { fmt.Print(m) // Output: {"a":1, "b":2} } + +func ExampleCollectMap() { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m2 := CollectMap(m1.All()) + fmt.Println(m2.String()) + // Output: {"a":1, "b":2, "c":3} +} + +func TestMap_Clone(t *testing.T) { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m2 := CollectMap(m1.All()) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) +} diff --git a/mapi.go b/mapi.go index 0c564d4..6b77575 100644 --- a/mapi.go +++ b/mapi.go @@ -1,5 +1,7 @@ package maps +import "iter" + // MapI is the interface used by all the Map types. type MapI[K comparable, V any] interface { Setter[K, V] @@ -14,6 +16,12 @@ type MapI[K comparable, V any] interface { Merge(MapI[K, V]) Equal(MapI[K, V]) bool Delete(k K) + All() iter.Seq2[K, V] + KeysIter() iter.Seq[K] + ValuesIter() iter.Seq[V] + Insert(seq iter.Seq2[K, V]) + DeleteFunc(del func(K, V) bool) + String() string } // Setter sets a value in a map. @@ -30,3 +38,22 @@ type Getter[K comparable, V any] interface { type Loader[K comparable, V any] interface { Load(k K) (v V, ok bool) } + +// EqualFunc returns true if all the keys and values of the m1 and m2 are equal. +// +// The function eq is called on the values to determine equality. Keys are compared using ==. +// If one of the maps is a "safe" map, its more efficient to pass that map as m2. +func EqualFunc[K comparable, V1, V2 any](m1 MapI[K, V1], m2 MapI[K, V2], eq func(V1, V2) bool) bool { + if m1.Len() != m2.Len() { + return false + } + ret := true + m2.Range(func(k K, v V2) bool { + if !m1.Has(k) || !eq(m1.Get(k), v) { + ret = false + return false + } + return true + }) + return ret +} diff --git a/mapi_test.go b/mapi_test.go index 75b81c9..de7f075 100644 --- a/mapi_test.go +++ b/mapi_test.go @@ -5,6 +5,9 @@ import ( "encoding/gob" "encoding/json" "github.com/stretchr/testify/assert" + "iter" + "slices" + "strconv" "testing" ) @@ -34,6 +37,11 @@ func runMapiTests[M any](t *testing.T, f makeF) { testMarshalJSON(t, f) testUnmarshalJSON[M](t, f) testDelete(t, f) + testAll(t, f) + testKeysIter(t, f) + testValuesIter(t, f) + testInsert(t, f) + testDeleteFunc(t, f) } func testClear(t *testing.T, f makeF) { @@ -254,3 +262,139 @@ func testDelete(t *testing.T, f makeF) { m.Delete("b") // make sure deleting from an empty map is a no-op }) } + +func testAll(t *testing.T, f makeF) { + t.Run("All", func(t *testing.T) { + m := f(mapT{"a": 1, "b": 2, "c": 3}) + + var actualKeys []string + var actualValues []int + + for k, v := range m.All() { + actualKeys = append(actualKeys, k) + actualValues = append(actualValues, v) + } + slices.Sort(actualKeys) + slices.Sort(actualValues) + + assert.Equal(t, []string{"a", "b", "c"}, actualKeys) + assert.Equal(t, []int{1, 2, 3}, actualValues) + }) +} + +// An iterator that prematurely stops at 2 items. +func limit2[V any](s iter.Seq[V]) iter.Seq[V] { + return func(yield func(V) bool) { + count := 0 + s(func(item V) bool { + count++ + if !yield(item) { + return false + } + if count == 2 { + return false + } + return true + }) + } +} + +func testKeysIter(t *testing.T, f makeF) { + tests := []struct { + name string + m mapTI + s []string + }{ + {"nil", f(), nil}, + {"0", f(mapT{}), nil}, + {"1", f(mapT{"a": 1}), []string{"a"}}, + {"2", f(mapT{"a": 1, "b": 2}), []string{"a", "b"}}, + {"3", f(mapT{"a": 1, "b": 2, "c": 3}), []string{"a", "b", "c"}}, + } + for _, tt := range tests { + t.Run("KeysIter "+tt.name, func(t *testing.T) { + s := slices.Collect(tt.m.KeysIter()) + slices.Sort(s) + assert.Equal(t, tt.s, s) + }) + } + + m := f(mapT{"a": 1, "b": 2, "c": 3}) + s := slices.Collect(limit2(m.KeysIter())) + assert.Len(t, s, 2) +} + +func testValuesIter(t *testing.T, f makeF) { + tests := []struct { + name string + m mapTI + s []int + }{ + {"nil", f(), nil}, + {"0", f(mapT{}), nil}, + {"1", f(mapT{"a": 1}), []int{1}}, + {"2", f(mapT{"a": 1, "b": 2}), []int{1, 2}}, + {"3", f(mapT{"a": 1, "b": 2, "c": 3}), []int{1, 2, 3}}, + } + for _, tt := range tests { + t.Run("ValuesIter "+tt.name, func(t *testing.T) { + s := slices.Collect(tt.m.ValuesIter()) + slices.Sort(s) + assert.Equal(t, tt.s, s) + }) + } + + m := f(mapT{"a": 1, "b": 2, "c": 3}) + s := slices.Collect(limit2(m.ValuesIter())) + assert.Len(t, s, 2) +} + +func testInsert(t *testing.T, f makeF) { + t.Run("Insert", func(t *testing.T) { + m1 := mapT{"a": 1, "b": 2, "c": 3} + m2 := f(mapT{"a": 1}) + m2.Insert(m1.All()) + assert.True(t, m1.Equal(m2)) + }) +} + +func testDeleteFunc(t *testing.T, f makeF) { + t.Run("DeleteFunc", func(t *testing.T) { + m1 := f(mapT{"a": 1, "b": 2, "c": 3}) + m1.DeleteFunc(func(k string, v int) bool { + return v != 2 + }) + assert.Equal(t, 1, m1.Len()) + }) +} + +func TestEqualFunc(t *testing.T) { + type testCase[K comparable, V1 any, V2 any] struct { + name string + m1 MapI[K, V1] + m2 MapI[K, V2] + want bool + } + tests := []testCase[string, int, string]{ + {"Equal Maps", NewMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"a": "1"}), true}, + {"Unequal Keys", NewMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"b": "1"}), false}, + {"Unequal Values", NewMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"a": "2"}), false}, + {"Equal SafeMap", NewSafeMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"a": "1"}), true}, + {"Equal SliceMap", NewSliceMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"a": "1"}), true}, + {"Equal SafeSliceMap", NewSafeSliceMap(StdMap[string, int]{"a": 1}), NewMap(StdMap[string, string]{"a": "1"}), true}, + {"Equal SafeSliceMap 2", NewMap(StdMap[string, int]{"a": 1}), NewSafeSliceMap(StdMap[string, string]{"a": "1"}), true}, + {"Equal Empty Map", NewMap(StdMap[string, int]{}), NewMap(StdMap[string, string]{}), true}, + {"Equal Empty SafeSliceMap", NewSafeSliceMap(StdMap[string, int]{}), NewMap(StdMap[string, string]{}), true}, + {"Unequal Empty Map", NewSafeSliceMap(StdMap[string, int]{}), NewMap(StdMap[string, string]{"a": "1"}), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, EqualFunc(tt.m1, tt.m2, isEqual), "EqualFunc(%v, %v)", tt.m1.String(), tt.m2.String()) + }) + } +} + +func isEqual(i int, s string) bool { + i2, _ := strconv.Atoi(s) + return i == i2 +} diff --git a/safe_map.go b/safe_map.go index a330970..0edc4e5 100644 --- a/safe_map.go +++ b/safe_map.go @@ -1,6 +1,7 @@ package maps import ( + "iter" "sync" ) @@ -15,11 +16,23 @@ import ( // m := new(MyMap) // // This will allow you to swap in a different kind of Map just by changing the type. +// +// Do not make a copy of a SafeMap using the equality operator (=). Use Clone instead. type SafeMap[K comparable, V any] struct { sync.RWMutex items StdMap[K, V] } +// NewSafeMap creates a new SafeMap. +// Pass in zero or more standard maps and the contents of those maps will be copied to the new SafeMap. +func NewSafeMap[K comparable, V any](sources ...map[K]V) *SafeMap[K, V] { + m := new(SafeMap[K, V]) + for _, i := range sources { + m.Copy(Cast(i)) + } + return m +} + // Clear resets the map to an empty map. func (m *SafeMap[K, V]) Clear() { if m.items == nil { @@ -123,13 +136,19 @@ func (m *SafeMap[K, V]) Range(f func(k K, v V) bool) { } // Merge merges the given map with the current one. The given one takes precedent on collisions. +// Deprecated: Use Copy instead. func (m *SafeMap[K, V]) Merge(in MapI[K, V]) { + m.Copy(in) +} + +// Copy copies the keys and values of in into this map, overwriting any duplicates. +func (m *SafeMap[K, V]) Copy(in MapI[K, V]) { if m.items == nil { m.items = make(map[K]V, in.Len()) } m.Lock() defer m.Unlock() - m.items.Merge(in) + m.items.Copy(in) } // Equal returns true if all the keys in the given map exist in this map, and the values are the same @@ -175,3 +194,80 @@ func (m *SafeMap[K, V]) String() string { defer m.RUnlock() return m.items.String() } + +// All returns an iterator over all the items in the map. +func (m *SafeMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + m.Range(yield) + } +} + +// KeysIter returns an iterator over all the keys in the map. +func (m *SafeMap[K, V]) KeysIter() iter.Seq[K] { + return func(yield func(K) bool) { + if m.items == nil { + return + } + m.RLock() + defer m.RUnlock() + for k, _ := range m.items { + if !yield(k) { + break + } + } + } +} + +// ValuesIter returns an iterator over all the values in the map. +func (m *SafeMap[K, V]) ValuesIter() iter.Seq[V] { + return func(yield func(V) bool) { + if m.items == nil { + return + } + m.RLock() + defer m.RUnlock() + for _, v := range m.items { + if !yield(v) { + break + } + } + } +} + +// Insert adds the values from seq to the map. +// Duplicate keys are overridden. +func (m *SafeMap[K, V]) Insert(seq iter.Seq2[K, V]) { + m.Lock() + defer m.Unlock() + for k, v := range seq { + m.items[k] = v + } +} + +// CollectSafeMap collects key-value pairs from seq into a new SafeMap +// and returns it. +func CollectSafeMap[K comparable, V any](seq iter.Seq2[K, V]) *SafeMap[K, V] { + m := new(SafeMap[K, V]) + m.items = StdMap[K, V]{} + for k, v := range seq { + m.items[k] = v + } + return m +} + +// Clone returns a copy of the SafeMap. This is a shallow clone: +// the new keys and values are set using ordinary assignment. +func (m *SafeMap[K, V]) Clone() *SafeMap[K, V] { + m1 := new(SafeMap[K, V]) + m.RLock() + defer m.RUnlock() + m1.items = m.items.Clone() + return m1 +} + +// DeleteFunc deletes any key/value pairs for which del returns true. +func (m *SafeMap[K, V]) DeleteFunc(del func(K, V) bool) { + m.Lock() + defer m.Unlock() + m.items.DeleteFunc(del) +} diff --git a/safe_map_test.go b/safe_map_test.go index 7140fae..31fd1da 100644 --- a/safe_map_test.go +++ b/safe_map_test.go @@ -38,3 +38,12 @@ func ExampleSafeMap_String() { fmt.Print(m) // Output: {"a":1, "b":2} } + +func TestCollectSafeMap(t *testing.T) { + m := StdMap[string, int]{"a": 1, "b": 2} + m2 := CollectSafeMap(m.All()) + assert.True(t, m.Equal(m2)) + + m3 := m2.Clone() + assert.True(t, m.Equal(m3)) +} diff --git a/safe_slice_map.go b/safe_slice_map.go index a7359ae..4703309 100644 --- a/safe_slice_map.go +++ b/safe_slice_map.go @@ -1,11 +1,9 @@ package maps import ( - "bytes" - "encoding/gob" - "encoding/json" "fmt" - "sort" + "iter" + "slices" "strings" "sync" ) @@ -27,11 +25,21 @@ import ( // This will allow you to swap in a different kind of Map just by changing the type. // // Call SetSortFunc to give the map a function that will keep the keys sorted in a particular order. +// +// Do not make a copy of a SafeSliceMap using the equality operator. Use Clone() instead. type SafeSliceMap[K comparable, V any] struct { sync.RWMutex - items StdMap[K, V] - order []K - lessF func(key1, key2 K, val1, val2 V) bool + sm SliceMap[K, V] +} + +// NewSafeSliceMap creates a new SafeSliceMap. +// Pass in zero or more standard maps and the contents of those maps will be copied to the new SafeSliceMap. +func NewSafeSliceMap[K comparable, V any](sources ...map[K]V) *SafeSliceMap[K, V] { + m := new(SafeSliceMap[K, V]) + for _, i := range sources { + m.Copy(Cast(i)) + } + return m } // SetSortFunc sets the sort function which will determine the order of the items in the map @@ -41,13 +49,7 @@ type SafeSliceMap[K comparable, V any] struct { func (m *SafeSliceMap[K, V]) SetSortFunc(f func(key1, key2 K, val1, val2 V) bool) { m.Lock() defer m.Unlock() - - m.lessF = f - if f != nil && len(m.order) > 0 { - sort.Slice(m.order, func(i, j int) bool { - return f(m.order[i], m.order[j], m.items[m.order[i]], m.items[m.order[j]]) - }) - } + m.sm.SetSortFunc(f) } // Set sets the given key to the given value. @@ -55,104 +57,32 @@ func (m *SafeSliceMap[K, V]) SetSortFunc(f func(key1, key2 K, val1, val2 V) bool // If the key already exists, the range order will not change. If you want the order // to change, call Delete first, and then Set. func (m *SafeSliceMap[K, V]) Set(key K, val V) { - var ok bool - var oldVal V - m.Lock() - - if m.items == nil { - m.items = make(map[K]V) - } - - _, ok = m.items[key] - if m.lessF != nil { - if ok { - // delete old key location - loc := sort.Search(len(m.items), func(n int) bool { - return !m.lessF(m.order[n], key, m.items[m.order[n]], oldVal) - }) - m.order = append(m.order[:loc], m.order[loc+1:]...) - } - - loc := sort.Search(len(m.order), func(n int) bool { - return m.lessF(key, m.order[n], val, m.items[m.order[n]]) - }) - // insert - m.order = append(m.order, key) - copy(m.order[loc+1:], m.order[loc:]) - m.order[loc] = key - } else { - if !ok { - m.order = append(m.order, key) - } - } - m.items[key] = val - m.Unlock() + defer m.Unlock() + m.sm.Set(key, val) } // SetAt sets the given key to the given value, but also inserts it at the index specified. // If the index is bigger than // the length, it puts it at the end. Negative indexes are backwards from the end. func (m *SafeSliceMap[K, V]) SetAt(index int, key K, val V) { - if m.lessF != nil { - panic("cannot use SetAt if you are also using a sort function") - } - - if index >= len(m.order) { - m.Set(key, val) - return - } - - var emptyKey K - - // Be careful here, since both Has and Delete need to acquire locks - if m.Has(key) { - m.Delete(key) - } m.Lock() - if index <= -len(m.items) { - index = 0 - } - if index < 0 { - index = len(m.items) + index - } - - m.order = append(m.order, emptyKey) - copy(m.order[index+1:], m.order[index:]) - m.order[index] = key - - m.items[key] = val - m.Unlock() + defer m.Unlock() + m.sm.SetAt(index, key, val) } // Delete removes the item with the given key. func (m *SafeSliceMap[K, V]) Delete(key K) { m.Lock() - if _, ok := m.items[key]; ok { - if m.lessF != nil { - oldVal := m.items[key] - loc := sort.Search(len(m.items), func(n int) bool { - return !m.lessF(m.order[n], key, m.items[m.order[n]], oldVal) - }) - m.order = append(m.order[:loc], m.order[loc+1:]...) - } else { - for i, v := range m.order { - if v == key { - m.order = append(m.order[:i], m.order[i+1:]...) - break - } - } - } - delete(m.items, key) - } - m.Unlock() + defer m.Unlock() + m.sm.Delete(key) } // Get returns the value based on its key. If the key does not exist, an empty value is returned. func (m *SafeSliceMap[K, V]) Get(key K) (val V) { m.RLock() defer m.RUnlock() - return m.items.Get(key) + return m.sm.Get(key) } // Load returns the value based on its key, and a boolean indicating whether it exists in the map. @@ -160,55 +90,49 @@ func (m *SafeSliceMap[K, V]) Get(key K) (val V) { func (m *SafeSliceMap[K, V]) Load(key K) (val V, ok bool) { m.RLock() defer m.RUnlock() - return m.items.Load(key) + return m.sm.Load(key) } // Has returns true if the given key exists in the map. func (m *SafeSliceMap[K, V]) Has(key K) (ok bool) { m.RLock() defer m.RUnlock() - return m.items.Has(key) + return m.sm.Has(key) } // GetAt returns the value based on its position. If the position is out of bounds, an empty value is returned. func (m *SafeSliceMap[K, V]) GetAt(position int) (val V) { m.RLock() defer m.RUnlock() - if position < len(m.order) && position >= 0 { - val, _ = m.items[m.order[position]] - } - return + return m.sm.GetAt(position) } // GetKeyAt returns the key based on its position. If the position is out of bounds, an empty value is returned. func (m *SafeSliceMap[K, V]) GetKeyAt(position int) (key K) { m.RLock() defer m.RUnlock() - if position < len(m.order) && position >= 0 { - key = m.order[position] - } - return + return m.sm.GetKeyAt(position) } // Values returns a slice of the values in the order they were added or sorted. func (m *SafeSliceMap[K, V]) Values() (vals []V) { m.RLock() defer m.RUnlock() - return m.items.Values() + return m.sm.Values() } // Keys returns the keys of the map, in the order they were added or sorted. func (m *SafeSliceMap[K, V]) Keys() (keys []K) { m.RLock() defer m.RUnlock() - return m.items.Keys() + return m.sm.Keys() } // Len returns the number of items in the map. func (m *SafeSliceMap[K, V]) Len() int { m.RLock() defer m.RUnlock() - return m.items.Len() + return m.sm.Len() } // MarshalBinary implements the BinaryMarshaler interface to convert the map to a byte stream. @@ -217,38 +141,15 @@ func (m *SafeSliceMap[K, V]) Len() int { func (m *SafeSliceMap[K, V]) MarshalBinary() (data []byte, err error) { m.RLock() defer m.RUnlock() - - buf := new(bytes.Buffer) - encoder := gob.NewEncoder(buf) - - err = encoder.Encode(map[K]V(m.items)) - if err == nil { - err = encoder.Encode(m.order) - } - data = buf.Bytes() - return + return m.sm.MarshalBinary() } // UnmarshalBinary implements the BinaryUnmarshaler interface to convert a byte stream to a // SafeSliceMap. func (m *SafeSliceMap[K, V]) UnmarshalBinary(data []byte) (err error) { - var items map[K]V - var order []K - m.Lock() defer m.Unlock() - - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - if err = dec.Decode(&items); err == nil { - err = dec.Decode(&order) - } - - if err == nil { - m.items = items - m.order = order - } - return err + return m.sm.UnmarshalBinary(data) } // MarshalJSON implements the json.Marshaler interface to convert the map into a JSON object. @@ -257,34 +158,27 @@ func (m *SafeSliceMap[K, V]) MarshalJSON() (data []byte, err error) { defer m.RUnlock() // Json objects are unordered - return m.items.MarshalJSON() + return m.sm.MarshalJSON() } // UnmarshalJSON implements the json.Unmarshaler interface to convert a json object to a Map. // The JSON must start with an object. func (m *SafeSliceMap[K, V]) UnmarshalJSON(data []byte) (err error) { - var items map[K]V - m.Lock() defer m.Unlock() - - if err = json.Unmarshal(data, &items); err == nil { - m.items = items - // Create a default order, since these are inherently unordered - m.order = make([]K, len(m.items)) - i := 0 - for k := range m.items { - m.order[i] = k - i++ - } - } - return + return m.sm.UnmarshalJSON(data) } // Merge the given map into the current one. +// Deprecated: Use copy instead. func (m *SafeSliceMap[K, V]) Merge(in MapI[K, V]) { + m.Copy(in) +} + +// Copy will copy the given map into the current one. +func (m *SafeSliceMap[K, V]) Copy(in MapI[K, V]) { in.Range(func(k K, v V) bool { - m.Set(k, v) // This will lock and unlock + m.Set(k, v) // This will lock and unlock, making sure that a long operation does not deadlock another go routine. return true }) } @@ -293,16 +187,12 @@ func (m *SafeSliceMap[K, V]) Merge(in MapI[K, V]) { // they were placed in the map, or in if you sorted the map, in your custom order. // If f returns false, it stops the iteration. This pattern is taken from sync.Map. func (m *SafeSliceMap[K, V]) Range(f func(key K, value V) bool) { - if m == nil || m.items == nil { + if m == nil || m.sm.items == nil { // prevent unnecessary lock return } m.RLock() defer m.RUnlock() - for _, k := range m.order { - if !f(k, m.items[k]) { - break - } - } + m.sm.Range(f) } // Equal returns true if all the keys and values are equal, regardless of the order. @@ -312,14 +202,13 @@ func (m *SafeSliceMap[K, V]) Range(f func(key K, value V) bool) { func (m *SafeSliceMap[K, V]) Equal(m2 MapI[K, V]) bool { m.RLock() defer m.RUnlock() - return m.items.Equal(m2) + return m.sm.Equal(m2) } // Clear removes all the items in the map. func (m *SafeSliceMap[K, V]) Clear() { m.Lock() - m.items = nil - m.order = nil + m.sm.Clear() m.Unlock() } @@ -338,3 +227,77 @@ func (m *SafeSliceMap[K, V]) String() string { s += "}" return s } + +// All returns an iterator over all the items in the map in the order they were entered or sorted. +func (m *SafeSliceMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + m.Range(yield) + } +} + +// KeysIter returns an iterator over all the keys in the map. +func (m *SafeSliceMap[K, V]) KeysIter() iter.Seq[K] { + return func(yield func(K) bool) { + if m == nil || m.sm.items == nil { + return + } + m.RLock() + defer m.RUnlock() + m.sm.KeysIter()(yield) + } +} + +// ValuesIter returns an iterator over all the values in the map. +func (m *SafeSliceMap[K, V]) ValuesIter() iter.Seq[V] { + return func(yield func(V) bool) { + if m == nil || m.sm.items == nil { + return + } + m.RLock() + defer m.RUnlock() + m.sm.ValuesIter()(yield) + } +} + +// Insert adds the values from seq to the end of the map. +// Duplicate keys are overridden but not moved. +// Will lock and unlock for each item in seq to give time to other go routines. +func (m *SafeSliceMap[K, V]) Insert(seq iter.Seq2[K, V]) { + for k, v := range seq { + m.Set(k, v) + } +} + +// CollectSafeSliceMap collects key-value pairs from seq into a new SafeSliceMap +// and returns it. +func CollectSafeSliceMap[K comparable, V any](seq iter.Seq2[K, V]) *SafeSliceMap[K, V] { + m := new(SafeSliceMap[K, V]) + + // no need to lock here since this is a private variable + for k, v := range seq { + m.sm.Set(k, v) + } + return m +} + +// Clone returns a copy of the SafeSliceMap. This is a shallow clone of the keys and values: +// the new keys and values are set using ordinary assignment. The order is preserved. +func (m *SafeSliceMap[K, V]) Clone() *SafeSliceMap[K, V] { + m1 := new(SafeSliceMap[K, V]) + m.RLock() + defer m.RUnlock() + m1.sm.items = m.sm.items.Clone() + m1.sm.order = slices.Clone(m.sm.order) + m1.sm.lessF = m.sm.lessF + return m1 +} + +// DeleteFunc deletes any key/value pairs for which del returns true. +// Items are ranged in order. +// This function locks the entire slice structure for the entirety of the call, +// so be careful to avoid deadlocks when calling this on a very big structure. +func (m *SafeSliceMap[K, V]) DeleteFunc(del func(K, V) bool) { + m.Lock() + defer m.Unlock() + m.sm.DeleteFunc(del) +} diff --git a/safe_slice_map_test.go b/safe_slice_map_test.go index a1d811c..4666229 100644 --- a/safe_slice_map_test.go +++ b/safe_slice_map_test.go @@ -121,3 +121,50 @@ func TestSafeSliceMap_GetAt(t *testing.T) { assert.Equal(t, 0, m.GetAt(0)) assert.Equal(t, "", m.GetKeyAt(0)) } + +func TestSafeSliceMap_Clone(t *testing.T) { + // Create a new SafeSliceMap and populate it + originalMap := NewSafeSliceMap[string, int]() + originalMap.Set("b", 2) + originalMap.Set("a", 1) + originalMap.Set("c", 3) + + // Clone the original map + clonedMap := originalMap.Clone() + + assert.True(t, clonedMap.Equal(originalMap)) + + // Verify that modifying the cloned map does not affect the original map + clonedMap.Set("a", 100) + if originalMap.Get("a") == 100 { + t.Error("Modification in cloned map affected the original map") + } + + // Verify that the original map remains unchanged + if originalMap.Get("a") != 1 { + t.Errorf("Expected value for key 'a' in original map to be %d, got %d", 1, originalMap.Get("a")) + } + + // Verify order is the same + values := clonedMap.Values() + expectedValues := []int{2, 100, 3} + assert.Equal(t, expectedValues, values) +} + +func TestCollectSafeSliceMap(t *testing.T) { + // Create a sequence of key-value pairs + s := NewSafeSliceMap[string, int]() + s.Set("b", 2) + s.Set("a", 1) + s.Set("c", 3) + seq := s.All() + // Use CollectSafeSliceMap to create a new SafeSliceMap + collectedMap := CollectSliceMap(seq) + + assert.True(t, s.Equal(collectedMap)) + + // Ensure the order of keys follows the insertion order + keys := collectedMap.Keys() + expectedKeys := []string{"b", "a", "c"} + assert.Equal(t, keys, expectedKeys) +} diff --git a/set.go b/set.go index f7a6267..8e9ba59 100644 --- a/set.go +++ b/set.go @@ -5,9 +5,10 @@ import ( "encoding/gob" "encoding/json" "fmt" + "iter" ) -// Set is a collection the keeps track of membership. +// Set is a collection that keeps track of membership. // // The recommended way to create a Set is to first declare a concrete type alias, and then call // new on it, like this: @@ -21,6 +22,14 @@ type Set[K comparable] struct { items StdMap[K, struct{}] } +func NewSet[K comparable](values ...K) *Set[K] { + s := new(Set[K]) + for _, k := range values { + s.Add(k) + } + return s +} + // Clear resets the set to an empty set func (m *Set[K]) Clear() { m.items = nil @@ -74,11 +83,14 @@ func (m *Set[K]) Add(k ...K) SetI[K] { } // Merge adds the values from the given set to the set. +// Deprecated: Call Copy instead. func (m *Set[K]) Merge(in SetI[K]) { - if m == nil { - panic("cannot merge into a nil set") - } - if in == nil { + m.Copy(in) +} + +// Copy adds the values from in to the set. +func (m *Set[K]) Copy(in SetI[K]) { + if in == nil || in.Len() == 0 { return } if m.items == nil { @@ -162,3 +174,44 @@ func (m *Set[K]) String() string { ret += "}" return ret } + +// All returns an iterator over all the items in the set. Order is not determinate. +func (m *Set[K]) All() iter.Seq[K] { + return m.items.KeysIter() +} + +// Insert adds the values from seq to the map. +// Duplicates are overridden. +func (m *Set[K]) Insert(seq iter.Seq[K]) { + if m.items == nil { + m.items = NewStdMap[K, struct{}]() + } + + for k := range seq { + m.Add(k) + } +} + +// CollectSet collects values from seq into a new Set +// and returns it. +func CollectSet[K comparable](seq iter.Seq[K]) *Set[K] { + m := NewSet[K]() + m.Insert(seq) + return m +} + +// Clone returns a copy of the Set. This is a shallow clone: +// the new keys and values are set using ordinary assignment. +func (m *Set[K]) Clone() *Set[K] { + m1 := NewSet[K]() + m1.items = m.items.Clone() + return m1 +} + +// DeleteFunc deletes any values for which del returns true. +func (m *Set[K]) DeleteFunc(del func(K) bool) { + del2 := func(k K, s struct{}) bool { + return del(k) + } + m.items.DeleteFunc(del2) +} diff --git a/set_test.go b/set_test.go index d867cdd..d2fa083 100644 --- a/set_test.go +++ b/set_test.go @@ -3,6 +3,7 @@ package maps import ( "encoding/gob" "fmt" + "github.com/stretchr/testify/assert" "sort" "testing" ) @@ -28,3 +29,17 @@ func ExampleSet_String() { fmt.Print(v) // Output: [a b] } + +func ExampleCollectSet() { + m1 := NewSet("a", "b", "c") + m2 := CollectSet(m1.All()) + fmt.Println(m2.String()) + // Output: {"a","b","c"} +} + +func TestSet_Clone(t *testing.T) { + m1 := NewSet("a", "b", "c") + m2 := CollectSet(m1.All()) + m3 := m2.Clone() + assert.True(t, m1.Equal(m3)) +} diff --git a/seti.go b/seti.go index 74d1a02..f5db475 100644 --- a/seti.go +++ b/seti.go @@ -1,5 +1,7 @@ package maps +import "iter" + // SetI is the interface used by all the Set types. type SetI[K comparable] interface { Add(k ...K) SetI[K] @@ -11,4 +13,8 @@ type SetI[K comparable] interface { Merge(SetI[K]) Equal(SetI[K]) bool Delete(k K) + All() iter.Seq[K] + Insert(seq iter.Seq[K]) + Clone() *Set[K] + DeleteFunc(del func(K) bool) } diff --git a/seti_test.go b/seti_test.go index a209daf..9395144 100644 --- a/seti_test.go +++ b/seti_test.go @@ -5,6 +5,7 @@ import ( "encoding/gob" "encoding/json" "github.com/stretchr/testify/assert" + "slices" "testing" ) @@ -33,6 +34,9 @@ func runSetITests[M any](t *testing.T, f makeSetF) { testSetMarshalJSON(t, f) testSetUnmarshalJSON[M](t, f) testSetDelete(t, f) + testSetAll(t, f) + testSetInsert(t, f) + testSetDeleteFunc(t, f) } func testSetClear(t *testing.T, f makeSetF) { @@ -224,3 +228,37 @@ func testSetDelete(t *testing.T, f makeSetF) { m.Delete("b") // make sure deleting from an empty map is a no-op }) } + +func testSetAll(t *testing.T, f makeSetF) { + t.Run("All", func(t *testing.T) { + m := f("a", "b", "c") + + var actualValues []string + + for k := range m.All() { + actualValues = append(actualValues, k) + } + slices.Sort(actualValues) + + assert.Equal(t, []string{"a", "b", "c"}, actualValues) + }) +} + +func testSetInsert(t *testing.T, f makeSetF) { + t.Run("Insert", func(t *testing.T) { + m1 := f("a", "b", "c") + m2 := f("a") + m2.Insert(m1.All()) + assert.True(t, m1.Equal(m2)) + }) +} + +func testSetDeleteFunc(t *testing.T, f makeSetF) { + t.Run("DeleteFunc", func(t *testing.T) { + m1 := f("a", "b", "c") + m1.DeleteFunc(func(k string) bool { + return k != "b" + }) + assert.Equal(t, 1, m1.Len()) + }) +} diff --git a/slice_map.go b/slice_map.go index b9e1b7b..ebe4c47 100644 --- a/slice_map.go +++ b/slice_map.go @@ -5,6 +5,8 @@ import ( "encoding/gob" "encoding/json" "fmt" + "iter" + "slices" "sort" "strings" ) @@ -30,6 +32,16 @@ type SliceMap[K comparable, V any] struct { lessF func(key1, key2 K, val1, val2 V) bool } +// NewSliceMap creates a new SliceMap. +// Pass in zero or more standard maps and the contents of those maps will be copied to the new SafeMap. +func NewSliceMap[K comparable, V any](sources ...map[K]V) *SliceMap[K, V] { + m := new(SliceMap[K, V]) + for _, i := range sources { + m.Copy(Cast(i)) + } + return m +} + // SetSortFunc sets the sort function which will determine the order of the items in the map // on an ongoing basis. Normally, items will iterate in the order they were added. // @@ -94,10 +106,6 @@ func (m *SliceMap[K, V]) Set(key K, val V) { // If the index is bigger than // the length, it puts it at the end. Negative indexes are backwards from the end. func (m *SliceMap[K, V]) SetAt(index int, key K, val V) { - if m == nil { - panic("cannot set a value on a nil SliceMap") - } - if m.lessF != nil { panic("cannot use SetAt if you are also using a sort function") } @@ -140,11 +148,11 @@ func (m *SliceMap[K, V]) Delete(key K) { loc := sort.Search(len(m.items), func(n int) bool { return !m.lessF(m.order[n], key, m.items[m.order[n]], oldVal) }) - m.order = append(m.order[:loc], m.order[loc+1:]...) + m.order = slices.Delete(m.order, loc, loc+1) } else { for i, v := range m.order { if v == key { - m.order = append(m.order[:i], m.order[i+1:]...) + m.order = slices.Delete(m.order, i, i+1) break } } @@ -200,20 +208,24 @@ func (m *SliceMap[K, V]) GetKeyAt(position int) (key K) { return } -// Values returns a slice of the values in the order they were added or sorted. +// Values returns a new slice of the values in the order they were added or sorted. func (m *SliceMap[K, V]) Values() (vals []V) { if m == nil { return } - return m.items.Values() + values := make([]V, 0, len(m.order)) + for _, k := range m.order { + values = append(values, m.items[k]) + } + return values } -// Keys returns the keys of the map, in the order they were added or sorted +// Keys returns a new slice of the keys of the map, in the order they were added or sorted func (m *SliceMap[K, V]) Keys() (keys []K) { if m == nil { return } - return m.items.Keys() + return slices.Clone(m.order) } // Len returns the number of items in the map @@ -296,6 +308,7 @@ func (m *SliceMap[K, V]) UnmarshalJSON(data []byte) (err error) { } // Merge the given map into the current one. +// Deprecated: use Copy instead. func (m *SliceMap[K, V]) Merge(in MapI[K, V]) { in.Range(func(k K, v V) bool { m.Set(k, v) @@ -303,6 +316,15 @@ func (m *SliceMap[K, V]) Merge(in MapI[K, V]) { }) } +// Copy copies the keys and values of in into the current one. +// Duplicate keys will have the values replaced, but not the order. +func (m *SliceMap[K, V]) Copy(in MapI[K, V]) { + in.Range(func(k K, v V) bool { + m.Set(k, v) + return true + }) +} + // Range will call the given function with every key and value in the order // they were placed in the map, or in if you sorted the map, in your custom order. // If f returns false, it stops the iteration. This pattern is taken from sync.Map. @@ -354,36 +376,74 @@ func (m *SliceMap[K, V]) String() string { return s } -// Equaler is the interface that implements an Equal function and that provides a way for the -// various MapI like objects to determine if they are equal. -// -// In particular, if your Map has -// non-comparible values, like a slice, but you would still like to call Equal() on that -// map, define an Equal function on the values to do the comparison. For example: -// -// type mySlice []int -// -// func (s mySlice) Equal(b any) bool { -// if s2, ok := b.(mySlice); ok { -// if len(s) == len(s2) { -// for i, v := range s2 { -// if s[i] != v { -// return false -// } -// } -// return true -// } -// } -// return false -// } -type Equaler interface { - Equal(a any) bool +// All returns an iterator over all the items in the map in the order they were entered or sorted. +func (m *SliceMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + m.Range(yield) + } } -func equalValues(a, b any) bool { - if e, ok := a.(Equaler); ok { - return e.Equal(b) +// KeysIter returns an iterator over all the keys in the map. +func (m *SliceMap[K, V]) KeysIter() iter.Seq[K] { + return func(yield func(K) bool) { + if m == nil || m.items == nil { + return + } + for _, k := range m.order { + if !yield(k) { + break + } + } } +} + +// ValuesIter returns an iterator over all the values in the map. +func (m *SliceMap[K, V]) ValuesIter() iter.Seq[V] { + return func(yield func(V) bool) { + if m == nil || m.items == nil { + return + } + for _, k := range m.order { + if !yield(m.items[k]) { + break + } + } + } +} - return a == b +// Insert adds the values from seq to the end of the map. +// Duplicate keys are overridden but not moved. +func (m *SliceMap[K, V]) Insert(seq iter.Seq2[K, V]) { + for k, v := range seq { + m.Set(k, v) + } +} + +// CollectSliceMap collects key-value pairs from seq into a new SliceMap +// and returns it. +func CollectSliceMap[K comparable, V any](seq iter.Seq2[K, V]) *SliceMap[K, V] { + m := new(SliceMap[K, V]) + m.Insert(seq) + return m +} + +// Clone returns a copy of the SliceMap. This is a shallow clone of the keys and values: +// the new keys and values are set using ordinary assignment. The order is preserved. +func (m *SliceMap[K, V]) Clone() *SliceMap[K, V] { + m1 := new(SliceMap[K, V]) + m1.items = m.items.Clone() + m1.order = slices.Clone(m.order) + m1.lessF = m.lessF + return m1 +} + +// DeleteFunc deletes any key/value pairs for which del returns true. +// Items are ranged in order. +func (m *SliceMap[K, V]) DeleteFunc(del func(K, V) bool) { + for i, k := range slices.Backward(m.order) { + if del(k, m.items[k]) { + m.items.Delete(k) + m.order = slices.Delete(m.order, i, i+1) + } + } } diff --git a/slice_map_test.go b/slice_map_test.go index 624b571..2b5f172 100644 --- a/slice_map_test.go +++ b/slice_map_test.go @@ -100,7 +100,7 @@ func TestSliceMap_SetAt(t *testing.T) { // delete and set will put to end m.Delete("e") m.Set("e", 6) - assert.Equal(t, 6, m.GetAt(4)) + assert.Equal(t, 6, m.GetAt(m.Len()-1)) // Or force it to new location m.SetAt(3, "e", 6) @@ -169,3 +169,50 @@ func TestSliceMap_NilMap(t *testing.T) { }) assert.Equal(t, "", m.String()) } + +func TestSliceMap_Clone(t *testing.T) { + // Create a new SafeSliceMap and populate it + originalMap := NewSliceMap[string, int]() + originalMap.Set("b", 2) + originalMap.Set("a", 1) + originalMap.Set("c", 3) + + // Clone the original map + clonedMap := originalMap.Clone() + + assert.True(t, clonedMap.Equal(originalMap)) + + // Verify that modifying the cloned map does not affect the original map + clonedMap.Set("a", 100) + if originalMap.Get("a") == 100 { + t.Error("Modification in cloned map affected the original map") + } + + // Verify that the original map remains unchanged + if originalMap.Get("a") != 1 { + t.Errorf("Expected value for key 'a' in original map to be %d, got %d", 1, originalMap.Get("a")) + } + + // Verify order is the same + values := clonedMap.Values() + expectedValues := []int{2, 100, 3} + assert.Equal(t, expectedValues, values) +} + +func TestCollectSliceMap(t *testing.T) { + // Create a sequence of key-value pairs + s := NewSliceMap[string, int]() + s.Set("b", 2) + s.Set("a", 1) + s.Set("c", 3) + seq := s.All() + // Use CollectSafeSliceMap to create a new SafeSliceMap + collectedMap := CollectSafeSliceMap(seq) + + assert.True(t, s.Equal(collectedMap)) + + // Ensure the order of keys follows the insertion order + keys := collectedMap.Keys() + expectedKeys := []string{"b", "a", "c"} + assert.Equal(t, keys, expectedKeys) +} diff --git a/std_map.go b/std_map.go index 5e88a2a..047756b 100644 --- a/std_map.go +++ b/std_map.go @@ -5,6 +5,8 @@ import ( "encoding/gob" "encoding/json" "fmt" + "iter" + "maps" "strings" ) @@ -12,7 +14,8 @@ import ( // // The zero value is NOT settable. Use NewStdMap to create a new StdMap object, or use standard // map instantiation syntax like this: -// m := StdMap[string, int]{"a":1} +// +// m := StdMap[string, int]{"a":1} // // StdMap is mostly a convenience type for making a standard Go map into a MapI interface. // Generally, you should use Map instead, as it presents a consistent interface that allows you @@ -22,11 +25,12 @@ type StdMap[K comparable, V any] map[K]V // NewStdMap creates a new map that maps values of type K to values of type V. // Pass in zero or more standard maps and the contents of those maps will be copied to the new StdMap. // You can also create a new StdMap like this: -// m := StdMap[string, int]{"a":1} +// +// m := StdMap[string, int]{"a":1} func NewStdMap[K comparable, V any](sources ...map[K]V) StdMap[K, V] { m := StdMap[K, V]{} for _, i := range sources { - m.Merge(Cast(i)) + m.Copy(Cast(i)) } return m } @@ -51,9 +55,15 @@ func (m StdMap[K, V]) Len() int { } // Merge copies the items from in to the map, overwriting any conflicting keys. +// Deprecated: use Copy instead func (m StdMap[K, V]) Merge(in MapI[K, V]) { + m.Copy(in) +} + +// Copy copies the items from in to the map, overwriting any conflicting keys. +func (m StdMap[K, V]) Copy(in MapI[K, V]) { if m == nil { - panic("cannot merge into a nil map") + panic("cannot copy into a nil map") } in.Range(func(k K, v V) bool { m[k] = v @@ -65,6 +75,8 @@ func (m StdMap[K, V]) Merge(in MapI[K, V]) { // This is the same interface as sync.Map.Range(). // While its safe to call methods of the map from within the Range function, its discouraged. // If you ever switch to one of the SafeMap maps, it will cause a deadlock. +// +// You can also range over a map using All(). func (m StdMap[K, V]) Range(f func(k K, v V) bool) { for k, v := range m { if !f(k, v) { @@ -141,7 +153,7 @@ func (m StdMap[K, V]) Values() (values []V) { // Equal returns true if all the keys and values are equal. // // If the values are not comparable, you should implement the Equaler interface on the values. -// Otherwise you will get a runtime panic. +// Otherwise, you will get a runtime panic. func (m StdMap[K, V]) Equal(m2 MapI[K, V]) bool { if m.Len() != m2.Len() { return false @@ -176,9 +188,10 @@ func (m StdMap[K, V]) MarshalBinary() ([]byte, error) { // UnmarshalBinary implements the BinaryUnmarshaler interface to convert a byte stream to a Map. // // Note that you will likely need to register the unmarshaller at init time with gob like this: -// func init() { -// gob.Register(new(Map[K,V])) -// } +// +// func init() { +// gob.Register(new(Map[K,V])) +// } func (m *StdMap[K, V]) UnmarshalBinary(data []byte) (err error) { b := bytes.NewBuffer(data) dec := gob.NewDecoder(b) @@ -203,3 +216,43 @@ func (m *StdMap[K, V]) UnmarshalJSON(in []byte) (err error) { *m = v return } + +// All returns an iterator over all the items in the map. +func (m StdMap[K, V]) All() iter.Seq2[K, V] { + return maps.All(m) +} + +// KeysIter returns an iterator over all the keys in the map. +func (m StdMap[K, V]) KeysIter() iter.Seq[K] { + return maps.Keys(m) +} + +// ValuesIter returns an iterator over all the values in the map. +func (m StdMap[K, V]) ValuesIter() iter.Seq[V] { + return maps.Values(m) +} + +// Insert adds the values from seq to the map. +// Duplicate keys are overridden. +func (m StdMap[K, V]) Insert(seq iter.Seq2[K, V]) { + maps.Insert(m, seq) +} + +// CollectStdMap collects key-value pairs from seq into a new StdMap +// and returns it. +func CollectStdMap[K comparable, V any](seq iter.Seq2[K, V]) StdMap[K, V] { + m := StdMap[K, V]{} + m.Insert(seq) + return m +} + +// Clone returns a copy of the StdMap. This is a shallow clone: +// the new keys and values are set using ordinary assignment. +func (m StdMap[K, V]) Clone() StdMap[K, V] { + return maps.Clone(m) +} + +// DeleteFunc deletes any key/value pairs for which del returns true. +func (m StdMap[K, V]) DeleteFunc(del func(K, V) bool) { + maps.DeleteFunc(m, del) +} diff --git a/std_map_test.go b/std_map_test.go index ae7a654..dde7118 100644 --- a/std_map_test.go +++ b/std_map_test.go @@ -2,8 +2,10 @@ package maps import ( "encoding/gob" + "encoding/json" "fmt" "github.com/stretchr/testify/assert" + "slices" "testing" ) @@ -236,3 +238,161 @@ func TestEqualValues(t *testing.T) { f := []float32{1, 2} assert.Panics(t, func() { equalValues(e, f) }) } + +func TestMarshalBinary(t *testing.T) { + m := StdMap[string, int]{"a": 1, "b": 2} + + // Marshal the map + data, err := m.MarshalBinary() + if err != nil { + t.Fatalf("Error marshalling: %v", err) + } + + // Unmarshal the data + var m2 StdMap[string, int] + err = m2.UnmarshalBinary(data) + if err != nil { + t.Fatalf("Error unmarshalling: %v", err) + } + + // Compare the original and unmarshalled maps + assert.Equal(t, m, m2) +} + +func TestMarshalJSON(t *testing.T) { + m := StdMap[string, int]{"a": 1, "b": 2} + + // Marshal the map to JSON + jsonData, err := json.Marshal(m) + if err != nil { + t.Fatalf("Error marshalling to JSON: %v", err) + } + + // Assert the JSON output + expectedJSON := `{"a":1,"b":2}` + assert.Equal(t, expectedJSON, string(jsonData)) +} + +func TestUnmarshalJSON(t *testing.T) { + jsonData := []byte(`{"a":1,"b":2}`) + + // Unmarshal the JSON into a StdMap + var m StdMap[string, int] + err := json.Unmarshal(jsonData, &m) + if err != nil { + t.Fatalf("Error unmarshalling from JSON: %v", err) + } + + // Assert the unmarshalled map + assert.Equal(t, 1, m["a"]) + assert.Equal(t, 2, m["b"]) +} + +func TestUnmarshalJSONInvalidInput(t *testing.T) { + invalidJSON := []byte(`invalid json`) + + // Unmarshal the invalid JSON + var m StdMap[string, int] + err := json.Unmarshal(invalidJSON, &m) + assert.Error(t, err) +} + +func TestDelete(t *testing.T) { + m := StdMap[string, int]{"a": 1, "b": 2} + + // Delete an existing key + m.Delete("a") + _, ok := m["a"] + assert.False(t, ok) + + // Delete a non-existent key + m.Delete("c") + // No error should occur, and the map should remain unchanged + assert.Equal(t, 2, m["b"]) +} + +func TestString(t *testing.T) { + m := StdMap[string, int]{"a": 1, "b": 2} + + // Get the string representation + str := m.String() + + // Check the string representation + expected := `{"a":1, "b":2}` + assert.Equal(t, expected, str) +} + +func ExampleStdMap_All() { + m := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + + var actualKeys []string + var actualValues []int + + for k, v := range m.All() { + actualKeys = append(actualKeys, k) + actualValues = append(actualValues, v) + } + slices.Sort(actualKeys) + slices.Sort(actualValues) + fmt.Println(actualKeys) + fmt.Println(actualValues) + + // Output: [a b c] + // [1 2 3] +} + +func ExampleStdMap_KeysIter() { + m := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + + var actualKeys []string + + for k := range m.KeysIter() { + actualKeys = append(actualKeys, k) + } + slices.Sort(actualKeys) + fmt.Println(actualKeys) + + // Output: [a b c] +} + +func ExampleStdMap_ValuesIter() { + m := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + + var actualValues []int + + for v := range m.ValuesIter() { + actualValues = append(actualValues, v) + } + slices.Sort(actualValues) + fmt.Println(actualValues) + + // Output: [1 2 3] +} + +func TestStdMap_Insert(t *testing.T) { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m2 := StdMap[string, int]{"a": 1} + m2.Insert(m1.All()) + assert.True(t, m1.Equal(m2)) +} + +func TestStdMap_Collect(t *testing.T) { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m2 := CollectStdMap(m1.All()) + assert.True(t, m1.Equal(m2)) +} + +func TestStdMap_Clone(t *testing.T) { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m2 := m1.Clone() + assert.True(t, m1.Equal(m2)) +} + +func ExampleStdMap_DeleteFunc() { + m1 := StdMap[string, int]{"a": 1, "b": 2, "c": 3} + m1.DeleteFunc(func(k string, v int) bool { + return v != 2 + }) + fmt.Println(m1.String()) + // Output: {"b":2} +} From 3c59731dfa45e10525523a4f1f512c177a3c53d3 Mon Sep 17 00:00:00 2001 From: spekary Date: Fri, 8 Nov 2024 17:40:50 -0800 Subject: [PATCH 2/2] Fixing testing environment --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index f4c6e9a..359f314 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.23 - name: Build run: go build -v ./...