From 932c84328201ee51d89ab3fe7f7ec4f98969482c Mon Sep 17 00:00:00 2001 From: bloeys Date: Sun, 21 Jul 2024 23:01:03 +0400 Subject: [PATCH] Keep count of set bits to offer .Len+optimizations --- nset.go | 119 +++++++++++++++++++++++++++++++++++++-------------- nset_test.go | 49 +++++++++++++++------ 2 files changed, 122 insertions(+), 46 deletions(-) diff --git a/nset.go b/nset.go index 00cb457..86ed5b4 100644 --- a/nset.go +++ b/nset.go @@ -1,10 +1,12 @@ package nset import ( + "bytes" "fmt" "math/bits" "reflect" "strings" + "unsafe" ) var _ fmt.Stringer = &NSet[uint8]{} @@ -13,14 +15,18 @@ type BucketType uint8 type StorageType uint64 const ( - BucketCount = 128 + BucketCount = 128 + // StorageTypeBits is the number of bits used per storage unit in each bucket. + // + // NOTE: this must be a power of 2, otherwise FastModPower2 will break and must be replaced by a normal x%y + // NOTE: GetStorageUnitIndex must be adjusted if this value is changed StorageTypeBits = 64 BucketIndexingBits = 7 ) -//IntsIf is limited to uint32 because we can store ALL 4 Billion uint32 numbers -//in 512MB with NSet (instead of the normal 16GB for an array of all uint32s). -//But if we allow uint64 (or int, since int can be 64-bit) users can easily put a big 64-bit number and use more RAM than maybe Google and crash. +// IntsIf is limited to uint32 because we can store ALL 4 Billion uint32 numbers +// in 512MB with NSet (instead of the normal 16GB for an array of all uint32s). +// But if we allow uint64 (or int, since int can be 64-bit) users can easily put a big 64-bit number and use more RAM than maybe Google and crash. type IntsIf interface { uint8 | uint16 | uint32 } @@ -35,6 +41,7 @@ type NSet[T IntsIf] struct { //StorageUnitCount the number of uint64 integers that are used to indicate presence of numbers in the set StorageUnitCount uint32 shiftAmount T + SetBits uint64 } func (n *NSet[T]) Add(x T) { @@ -51,7 +58,11 @@ func (n *NSet[T]) Add(x T) { bucket.StorageUnitCount += storageUnitsToAdd } - bucket.Data[unitIndex] |= n.GetBitMask(x) + oldStorage := bucket.Data[unitIndex] + newStorage := oldStorage | n.GetBitMask(x) + + bucket.Data[unitIndex] = newStorage + n.SetBits += uint64(bits.OnesCount64(uint64(^oldStorage) & uint64(newStorage))) } func (n *NSet[T]) AddMany(values ...T) { @@ -71,9 +82,12 @@ func (n *NSet[T]) AddMany(values ...T) { bucket.StorageUnitCount += storageUnitsToAdd } - bucket.Data[unitIndex] |= n.GetBitMask(x) - } + oldStorage := bucket.Data[unitIndex] + newStorage := oldStorage | n.GetBitMask(x) + bucket.Data[unitIndex] = newStorage + n.SetBits += uint64(bits.OnesCount64(uint64(^oldStorage) & uint64(newStorage))) + } } func (n *NSet[T]) Remove(x T) { @@ -84,7 +98,11 @@ func (n *NSet[T]) Remove(x T) { return } - b.Data[unitIndex] ^= n.GetBitMask(x) + oldStorage := b.Data[unitIndex] + newStorage := oldStorage &^ n.GetBitMask(x) + + b.Data[unitIndex] = newStorage + n.SetBits -= uint64(bits.OnesCount64(uint64(oldStorage) & uint64(^newStorage))) } func (n *NSet[T]) Contains(x T) bool { @@ -129,16 +147,22 @@ func (n *NSet[T]) GetBucketIndex(x T) BucketType { } func (n *NSet[T]) GetStorageUnitIndex(x T) uint32 { + //The top 'n' bits are used to select the bucket so we need to remove them before finding storage //unit and bit mask. This is done by shifting left by 4 which removes the top 'n' bits, //then shifting right by 4 which puts the bits back to their original place, but now //the top 'n' bits are zeros. - return uint32(((x << BucketIndexingBits) >> BucketIndexingBits) / StorageTypeBits) + + // Since StorageTypeBits is known and is a power of 2, we can replace the division + // with a right shift. + // + // The below return is equal to: return uint32(((x << BucketIndexingBits) >> BucketIndexingBits) / StorageTypeBits) + return uint32(((x << BucketIndexingBits) >> BucketIndexingBits) >> 6) } func (n *NSet[T]) GetBitMask(x T) StorageType { //Removes top 'n' bits - return 1 << (((x << BucketIndexingBits) >> BucketIndexingBits) % StorageTypeBits) + return 1 << FastModPower2(((x<>BucketIndexingBits), StorageTypeBits) } func (n *NSet[T]) Union(otherSet *NSet[T]) { @@ -158,7 +182,12 @@ func (n *NSet[T]) Union(otherSet *NSet[T]) { } for j := 0; j < len(b1.Data) && j < len(b2.Data); j++ { - b1.Data[j] |= b2.Data[j] + + oldStorage := b1.Data[j] + newStorage := oldStorage | b2.Data[j] + + b1.Data[j] = newStorage + n.SetBits += uint64(bits.OnesCount64(uint64(^oldStorage) & uint64(newStorage))) } } } @@ -187,19 +216,26 @@ func (n *NSet[T]) GetIntersection(otherSet *NSet[T]) *NSet[T] { outSet.StorageUnitCount += storageUnitsToAdd } - newB.Data[j] = b1.Data[j] & b2.Data[j] + newStorage := b1.Data[j] & b2.Data[j] + newB.Data[j] = newStorage + outSet.SetBits += uint64(bits.OnesCount64(uint64(newStorage))) } } return outSet } -//GetAllElements returns all the added numbers added to NSet. -//NOTE: Be careful with this if you have a lot of elements in NSet because NSet is compressed while the returned array is not. -//In the worst case (all uint32s stored) the returned array will be ~4.2 billion elements and will use 16+ GBs of RAM. +// GetAllElements returns all the added numbers added to NSet. +// +// NOTE: Be careful with this if you have a lot of elements in NSet because NSet is compressed while the returned array is not. +// In the worst case (all uint32s stored) the returned array will be ~4.2 billion elements and will use 16+ GBs of RAM. func (n *NSet[T]) GetAllElements() []T { - elements := make([]T, 0) + elements := make([]T, 0, n.SetBits) + + if n.SetBits == 0 { + return elements + } for i := 0; i < BucketCount; i++ { @@ -211,11 +247,11 @@ func (n *NSet[T]) GetAllElements() []T { for j := 0; j < len(b1.Data); j++ { storageUnit := b1.Data[j] - onesCount := bits.OnesCount64(uint64(storageUnit)) - if onesCount == 0 { + if storageUnit == 0 { continue } - elementsToAdd := make([]T, 0, onesCount) + + onesCount := bits.OnesCount64(uint64(storageUnit)) mask := StorageType(1 << 0) //This will be used to check set bits. Numbers will be reconstructed only for set bits firstStorageUnitValue := T(j*StorageTypeBits) | bucketIndexBits //StorageUnitIndex = noBucketBitsX / StorageTypeBits. So: noBucketBitsX = StorageUnitIndex * StorageTypeBits; Then: x = noBucketBitsX | bucketIndexBits @@ -223,14 +259,12 @@ func (n *NSet[T]) GetAllElements() []T { for k := T(0); onesCount > 0 && k < StorageTypeBits; k++ { if storageUnit&mask > 0 { - elementsToAdd = append(elementsToAdd, firstStorageUnitValue+k) + elements = append(elements, firstStorageUnitValue+k) onesCount-- } mask <<= 1 } - - elements = append(elements, elementsToAdd...) } } @@ -239,7 +273,7 @@ func (n *NSet[T]) GetAllElements() []T { func (n *NSet[T]) IsEq(otherSet *NSet[T]) bool { - if n.StorageUnitCount != otherSet.StorageUnitCount { + if n.SetBits != otherSet.SetBits { return false } @@ -255,11 +289,13 @@ func (n *NSet[T]) IsEq(otherSet *NSet[T]) bool { b1 := &n.Buckets[i] b2 := &otherSet.Buckets[i] - for j := 0; j < len(b1.Data); j++ { + bucketsEqual := (b1.StorageUnitCount == 0 && b2.StorageUnitCount == 0) || bytes.Equal( + unsafe.Slice((*byte)(unsafe.Pointer(&b1.Data[0])), len(b1.Data)*int(unsafe.Sizeof(b1.Data[0]))), + unsafe.Slice((*byte)(unsafe.Pointer(&b2.Data[0])), len(b2.Data)*int(unsafe.Sizeof(b2.Data[0]))), + ) - if b1.Data[j] != b2.Data[j] { - return false - } + if !bucketsEqual { + return false } } @@ -284,7 +320,7 @@ func (n *NSet[T]) HasIntersection(otherSet *NSet[T]) bool { return false } -//String returns a string of the storage as bytes separated by spaces. A comma is between each storage unit +// String returns a string of the storage as bytes separated by spaces. A comma is between each storage unit func (n *NSet[T]) String() string { b := strings.Builder{} @@ -334,9 +370,20 @@ func (n *NSet[T]) Copy() *NSet[T] { } +// Len returns the number of values stored (i.e. bits set to 1). +// It is the same as NSet.SetBits. +func (n *NSet[T]) Len() uint64 { + return n.SetBits +} + func UnionSets[T IntsIf](set1, set2 *NSet[T]) *NSet[T] { newSet := NewNSet[T]() + + // This is an optimization that makes it so that we only need to count bits + // when doing union with set2 + newSet.SetBits = set1.SetBits + for i := 0; i < BucketCount; i++ { b1 := &set1.Buckets[i] @@ -355,18 +402,26 @@ func UnionSets[T IntsIf](set1, set2 *NSet[T]) *NSet[T] { newSet.StorageUnitCount += bucketSize //Union fields of both sets on the new set - for j := 0; j < len(b1.Data); j++ { - newB.Data[j] |= b1.Data[j] - } + copy(newB.Data, b1.Data) for j := 0; j < len(b2.Data); j++ { - newB.Data[j] |= b2.Data[j] + + oldStorage := newB.Data[j] + newStorage := oldStorage | b2.Data[j] + + newB.Data[j] = newStorage + newSet.SetBits += uint64(bits.OnesCount64(uint64(^oldStorage) & uint64(newStorage))) } } return newSet } +// FastModPower2 is a fast version of x%y that only works when y is a power of 2 +func FastModPower2[T uint8 | uint16 | uint32 | uint64](x, y T) T { + return x & (y - 1) +} + func NewNSet[T IntsIf]() *NSet[T] { n := &NSet[T]{ diff --git a/nset_test.go b/nset_test.go index 791bcda..bbe9295 100755 --- a/nset_test.go +++ b/nset_test.go @@ -22,10 +22,23 @@ var ( func TestNSet(t *testing.T) { n1 := nset.NewNSet[uint32]() + + // Double add/remove of the same value is not only important to test SetBits, but also + // to test for bugs where double adding/removing incorrectly flips bits (checked using Contains()) n1.Add(0) + AllTrue(t, n1.Len() == 1) + + n1.Add(0) + AllTrue(t, n1.Len() == 1) + n1.Add(1) + AllTrue(t, n1.Len() == 2) + n1.Add(63) + AllTrue(t, n1.Len() == 3) + n1.Add(math.MaxUint32) + AllTrue(t, n1.Len() == 4) AllTrue(t, n1.Contains(0), n1.Contains(1), n1.Contains(63), n1.Contains(math.MaxUint32), !n1.Contains(10), !n1.Contains(599)) AllTrue(t, n1.ContainsAll(0, 1, 63), !n1.ContainsAll(9, 0, 1), !n1.ContainsAll(0, 1, 63, 99)) @@ -35,7 +48,12 @@ func TestNSet(t *testing.T) { IsEq(t, math.MaxUint32/64/nset.BucketCount, n1.GetStorageUnitIndex(math.MaxUint32)) nCopy := n1.Copy() + n1.Remove(1) + AllTrue(t, n1.Len() == 3) + + n1.Remove(1) + AllTrue(t, n1.Len() == 3) AllTrue(t, n1.Contains(0), n1.Contains(63), !n1.Contains(1), nCopy.ContainsAll(0, 1, 63, math.MaxUint32)) @@ -59,25 +77,25 @@ func TestNSet(t *testing.T) { n4n5Twin := nset.NewNSet[uint32]() n4n5Twin.AddMany(0, 1, 64, math.MaxUint32) - AllTrue(t, n4n5.ContainsAll(0, 1, 64, math.MaxUint32), !n4n5.Contains(63), n4n5Twin.IsEq(n4n5)) + AllTrue(t, n4n5.Len() == 4, n4n5.Len() == n4n5Twin.Len(), n4n5.ContainsAll(0, 1, 64, math.MaxUint32), !n4n5.Contains(63), n4n5Twin.IsEq(n4n5)) //Union n6 := nset.NewNSet[uint32]() - n6.AddMany(4, 7, 100, 1000) + n6.AddMany(1, 4, 7, 100, 1000) n7 := nset.NewNSet[uint32]() - n7.AddMany(math.MaxUint32) + n7.AddMany(1, math.MaxUint32) n7OldStorageUnitCount := n7.StorageUnitCount n7.Union(n6) - AllTrue(t, n6.ContainsAll(4, 7, 100, 1000), !n6.Contains(math.MaxUint32), n7.ContainsAll(4, 7, 100, 1000, math.MaxUint32), n7.StorageUnitCount == n7OldStorageUnitCount+n6.StorageUnitCount) + AllTrue(t, n6.Len() == 5, n7.Len() == 6, n6.ContainsAll(1, 4, 7, 100, 1000), !n6.Contains(math.MaxUint32), n7.ContainsAll(1, 4, 7, 100, 1000, math.MaxUint32), n7.StorageUnitCount == n7OldStorageUnitCount+n6.StorageUnitCount-1) //UnionSets n7 = nset.NewNSet[uint32]() - n7.AddMany(math.MaxUint32) + n7.AddMany(4, math.MaxUint32) unionedSet := nset.UnionSets(n6, n7) - AllTrue(t, !n6.Contains(math.MaxUint32), !n7.ContainsAny(4, 7, 100, 1000), unionedSet.ContainsAll(4, 7, 100, 1000, math.MaxUint32), unionedSet.StorageUnitCount == n6.StorageUnitCount+n7OldStorageUnitCount) + AllTrue(t, unionedSet.Len() == 6, !n6.Contains(math.MaxUint32), !n7.ContainsAny(7, 100, 1000), unionedSet.ContainsAll(4, 7, 100, 1000, math.MaxUint32), unionedSet.StorageUnitCount == n6.StorageUnitCount+n7OldStorageUnitCount-1) //Equality AllTrue(t, !n6.IsEq(n7)) @@ -130,15 +148,18 @@ func TestNSetFullRange(t *testing.T) { } -func AllTrue(t *testing.T, values ...bool) bool { +func AllTrue(t *testing.T, values ...bool) (success bool) { + + success = true for i := 0; i < len(values); i++ { if !values[i] { - t.Errorf("Expected 'true' but got 'false'\n") + t.Fatalf("Expected 'true' but got 'false'\n") + success = false } } - return true + return success } func IsEq[T comparable](t *testing.T, expected, val T) bool { @@ -147,7 +168,7 @@ func IsEq[T comparable](t *testing.T, expected, val T) bool { return true } - t.Errorf("Expected '%v' but got '%v'\n", expected, val) + t.Fatalf("Expected '%v' but got '%v'\n", expected, val) return false } @@ -814,12 +835,12 @@ func BenchmarkNSetUnionRand(b *testing.B) { b.StopTimer() - rand.Seed(RandSeed) + randGen := rand.New(rand.NewSource(RandSeed)) s1 := nset.NewNSet[uint32]() s2 := nset.NewNSet[uint32]() for i := uint32(0); i < maxBenchSize; i++ { - r := rand.Uint32() + r := randGen.Uint32() s1.Add(r) s2.Add(r) } @@ -837,12 +858,12 @@ func BenchmarkMapUnionRand(b *testing.B) { b.StopTimer() - rand.Seed(RandSeed) + randGen := rand.New(rand.NewSource(RandSeed)) m1 := map[uint32]struct{}{} m2 := map[uint32]struct{}{} for i := uint32(0); i < maxBenchSize; i++ { - r := rand.Uint32() + r := randGen.Uint32() m1[r] = struct{}{} m2[r] = struct{}{} }