diff --git a/bits.go b/bits.go index f61013a..f435b6d 100644 --- a/bits.go +++ b/bits.go @@ -53,3 +53,6 @@ func hasZeroByte(x uint64) bitset { func castUint64(m *metadata) uint64 { return *(*uint64)((unsafe.Pointer)(m)) } + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 diff --git a/bits_amd64.go b/bits_amd64.go index 120ebfd..8b91f57 100644 --- a/bits_amd64.go +++ b/bits_amd64.go @@ -18,6 +18,7 @@ package swiss import ( "math/bits" + _ "unsafe" "github.com/dolthub/swiss/simd" ) @@ -44,3 +45,6 @@ func nextMatch(b *bitset) (s uint32) { *b &= ^(1 << s) // clear bit |s| return } + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 diff --git a/go.mod b/go.mod index 5477652..87f1c8f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/dolthub/swiss go 1.18 require ( - github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 + github.com/dolthub/maphash v0.1.0 github.com/stretchr/testify v1.8.1 ) diff --git a/go.sum b/go.sum index ceaf83d..0e9147a 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 h1:SegEguMxToBn045KRHLIUlF2/jR7Y2qD6fF+3tdOfvI= -github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git 
a/map.go b/map.go index 6cc6dfe..07f01af 100644 --- a/map.go +++ b/map.go @@ -15,8 +15,6 @@ package swiss import ( - "math/rand" - "github.com/dolthub/maphash" ) @@ -214,7 +212,7 @@ func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) { // we rehash during iteration ctrl, groups := m.ctrl, m.groups // pick a random starting group - g := rand.Intn(len(groups)) + g := randIntN(len(groups)) for n := 0; n < len(groups); n++ { for s, c := range ctrl[g] { if c == empty || c == tombstone { @@ -226,17 +224,33 @@ func (m *Map[K, V]) Iter(cb func(k K, v V) (stop bool)) { } } g++ - if g >= len(groups) { + if g >= uint32(len(groups)) { g = 0 } } } +// Clear removes all elements from the Map. +func (m *Map[K, V]) Clear() { + for i, c := range m.ctrl { + for j := range c { + m.ctrl[i][j] = empty + } + } + m.resident, m.dead = 0, 0 +} + // Count returns the number of elements in the Map. func (m *Map[K, V]) Count() int { return int(m.resident - m.dead) } +// Capacity returns the number of additional elements +// that can be added to the Map before resizing. +func (m *Map[K, V]) Capacity() int { + return int(m.limit - m.resident) +} + // find returns the location of |key| if present, or its insertion location if absent. // for performance, find is manually inlined into public methods. func (m *Map[K, V]) find(key K, hi h1, lo h2) (g, s uint32, ok bool) { @@ -278,7 +292,7 @@ func (m *Map[K, V]) rehash(n uint32) { for i := range m.ctrl { m.ctrl[i] = newEmptyMetadata() } - m.hash = maphash.NewHasher[K]() + m.hash = maphash.NewSeed(m.hash) m.limit = n * maxAvgGroupLoad m.resident, m.dead = 0, 0 for g := range ctrl { @@ -325,3 +339,8 @@ func probeStart(hi h1, groups int) uint32 { func fastModN(x, n uint32) uint32 { return uint32((uint64(x) * uint64(n)) >> 32) } + +// randIntN returns a random number in the interval [0, n). 
+func randIntN(n int) uint32 { + return fastModN(fastrand(), uint32(n)) +} diff --git a/map_test.go b/map_test.go index 85f4e12..846edb0 100644 --- a/map_test.go +++ b/map_test.go @@ -50,6 +50,14 @@ func TestSwissMap(t *testing.T) { t.Run("uint32=100_000", func(t *testing.T) { testSwissMap(t, genUint32Data(100_000)) }) + t.Run("string capacity", func(t *testing.T) { + testSwissMapCapacity(t, func(n int) []string { + return genStringData(16, n) + }) + }) + t.Run("uint32 capacity", func(t *testing.T) { + testSwissMapCapacity(t, genUint32Data) + }) } func testSwissMap[K comparable](t *testing.T, keys []K) { @@ -67,6 +75,9 @@ func testSwissMap[K comparable](t *testing.T, keys []K) { t.Run("delete", func(t *testing.T) { testMapDelete(t, keys) }) + t.Run("clear", func(t *testing.T) { + testMapClear(t, keys) + }) t.Run("iter", func(t *testing.T) { testMapIter(t, keys) }) @@ -173,6 +184,29 @@ func testMapDelete[K comparable](t *testing.T, keys []K) { assert.Equal(t, 0, m.Count()) } +func testMapClear[K comparable](t *testing.T, keys []K) { + m := NewMap[K, int](0) + assert.Equal(t, 0, m.Count()) + for i, key := range keys { + m.Put(key, i) + } + assert.Equal(t, len(keys), m.Count()) + m.Clear() + assert.Equal(t, 0, m.Count()) + for _, key := range keys { + ok := m.Has(key) + assert.False(t, ok) + _, ok = m.Get(key) + assert.False(t, ok) + } + var calls int + m.Iter(func(k K, v int) (stop bool) { + calls++ + return + }) + assert.Equal(t, 0, calls) +} + func testMapIter[K comparable](t *testing.T, keys []K) { m := NewMap[K, int](uint32(len(keys))) for i, key := range keys { @@ -214,6 +248,32 @@ func testMapGrow[K comparable](t *testing.T, keys []K) { } } +func testSwissMapCapacity[K comparable](t *testing.T, gen func(n int) []K) { + // Capacity() behavior depends on |groupSize| + // which varies by processor architecture. 
+ caps := []uint32{ + 1 * maxAvgGroupLoad, + 2 * maxAvgGroupLoad, + 3 * maxAvgGroupLoad, + 4 * maxAvgGroupLoad, + 5 * maxAvgGroupLoad, + 10 * maxAvgGroupLoad, + 25 * maxAvgGroupLoad, + 50 * maxAvgGroupLoad, + 100 * maxAvgGroupLoad, + } + for _, c := range caps { + m := NewMap[K, K](c) + assert.Equal(t, int(c), m.Capacity()) + keys := gen(rand.Intn(int(c))) + for _, k := range keys { + m.Put(k, k) + } + assert.Equal(t, int(c)-len(keys), m.Capacity()) + assert.Equal(t, int(c), m.Count()+m.Capacity()) + } +} + func testProbeStats[K comparable](t *testing.T, keys []K) { runTest := func(load float32) { n := uint32(len(keys))