From 17b8c38162eaa8b14758ecb6d499b994cffb4c3c Mon Sep 17 00:00:00 2001 From: bloeys Date: Mon, 9 Sep 2024 07:47:36 +0400 Subject: [PATCH] Add GetDifference method --- nset.go | 39 +++++++++++++++++++++++++++++++++++++++ nset_test.go | 31 ++++++++++++++++++++++--------- 2 files changed, 61 insertions(+), 9 deletions(-) diff --git a/nset.go b/nset.go index 430b753..03fa76b 100644 --- a/nset.go +++ b/nset.go @@ -165,6 +165,8 @@ func (n *NSet[T]) GetBitMask(x T) StorageType { return 1 << FastModPower2(((x<>BucketIndexingBits), StorageTypeBits) } +// Union does n1=Union(n1, n2), so the current set will be updated +// such that its a union of its old value and the passed set func (n *NSet[T]) Union(otherSet *NSet[T]) { for i := 0; i < BucketCount; i++ { @@ -192,6 +194,8 @@ func (n *NSet[T]) Union(otherSet *NSet[T]) { } } +// GetIntersection returns a new set that's the intersection between +// this set and the passed set func (n *NSet[T]) GetIntersection(otherSet *NSet[T]) *NSet[T] { outSet := NewNSet[T]() @@ -225,6 +229,41 @@ func (n *NSet[T]) GetIntersection(otherSet *NSet[T]) *NSet[T] { return outSet } +// GetDifference returns a new set that contains the elements in this set +// that are not in the passed set. +// +// For example, if s1=(1,2,3,4,5) and s2=(1,3,4), the output is +// s3=Diff(s1,s2)=(2,5) +func (n *NSet[T]) GetDifference(otherSet *NSet[T]) *NSet[T] { + + outSet := NewNSet[T]() + + for i := 0; i < BucketCount; i++ { + + b1 := &n.Buckets[i] + b2 := &otherSet.Buckets[i] + + outSet.StorageUnitCount += b1.StorageUnitCount + + newB := &outSet.Buckets[i] + newB.StorageUnitCount = b1.StorageUnitCount + newB.Data = make([]StorageType, newB.StorageUnitCount) + + for j := uint32(0); j < b1.StorageUnitCount && j < b2.StorageUnitCount; j++ { + + newStorage := b1.Data[j] & (^b2.Data[j]) + newB.Data[j] = newStorage + outSet.SetBits += uint64(bits.OnesCount64(uint64(newStorage))) + } + + if b1.StorageUnitCount > b2.StorageUnitCount { + copy(newB.Data[b2.StorageUnitCount:], b1.Data[b2.StorageUnitCount:]) + } + } + + return outSet +} + // GetAllElements returns all the added numbers added to NSet. // // NOTE: Be careful with this if you have a lot of elements in NSet because NSet is compressed while the returned array is not. diff --git a/nset_test.go b/nset_test.go index 41a1398..89a3fe2 100755 --- a/nset_test.go +++ b/nset_test.go @@ -57,7 +57,7 @@ func TestNSet(t *testing.T) { AllTrue(t, n1.Contains(0), n1.Contains(63), !n1.Contains(1), nCopy.ContainsAll(0, 1, 63, math.MaxUint32)) - //Intersections + // Intersections n2 := nset.NewNSet[uint32]() n2.AddMany(1000, 63, 5, 10) @@ -79,7 +79,7 @@ func TestNSet(t *testing.T) { AllTrue(t, n4n5.Len() == 4, n4n5.Len() == n4n5Twin.Len(), n4n5.ContainsAll(0, 1, 64, math.MaxUint32), !n4n5.Contains(63), n4n5Twin.IsEq(n4n5)) - //Union + // Union n6 := nset.NewNSet[uint32]() n6.AddMany(1, 4, 7, 100, 1000) @@ -90,14 +90,14 @@ func TestNSet(t *testing.T) { AllTrue(t, n6.Len() == 5, n7.Len() == 6, n6.ContainsAll(1, 4, 7, 100, 1000), !n6.Contains(math.MaxUint32), n7.ContainsAll(1, 4, 7, 100, 1000, math.MaxUint32), n7.StorageUnitCount == n7OldStorageUnitCount+n6.StorageUnitCount-1) - //UnionSets + // UnionSets n7 = nset.NewNSet[uint32]() n7.AddMany(4, math.MaxUint32) unionedSet := nset.UnionSets(n6, n7) AllTrue(t, unionedSet.Len() == 6, !n6.Contains(math.MaxUint32), !n7.ContainsAny(7, 100, 1000), unionedSet.ContainsAll(4, 7, 100, 1000, math.MaxUint32), unionedSet.StorageUnitCount == n6.StorageUnitCount+n7OldStorageUnitCount-1) - //Equality + // Equality AllTrue(t, !n6.IsEq(n7)) n7.Union(n6) @@ -106,12 +106,27 @@ func TestNSet(t *testing.T) { n6.Union(n7) AllTrue(t, n6.IsEq(n7)) - //GetAllElements + // GetAllElements n8 := nset.NewNSet[uint32]() n8.AddMany(0, 1, 55, 1000, 10000) n8Elements := n8.GetAllElements() AllTrue(t, len(n8Elements) == 5, n8Elements[0] == 0, n8Elements[1] == 1, n8Elements[2] == 55, n8Elements[3] == 1000, n8Elements[4] == 10000) + + // GetDifference + nDiff1 := nset.NewNSet[uint32]() + nDiff1.AddMany(1, 2, 3) + + nDiff2 := nset.NewNSet[uint32]() + nDiff2.AddMany(1, 3, 4) + + nDiff3 := nDiff1.GetDifference(nDiff2) + + AllTrue(t, nDiff3.SetBits == 1, nDiff3.StorageUnitCount == 1, nDiff3.Contains(2), !nDiff3.ContainsAny(1, 3, 4)) + + nDiff1.AddMany(1, 2, 3, 4, 5, math.MaxUint32) + nDiff3 = nDiff1.GetDifference(nDiff2) + AllTrue(t, nDiff3.SetBits == 2, nDiff3.ContainsAll(2, 5, math.MaxUint32), !nDiff3.ContainsAny(1, 3, 4)) } func TestNSetFullRange(t *testing.T) { @@ -151,16 +166,14 @@ func TestNSetFullRange(t *testing.T) { func AllTrue(t *testing.T, values ...bool) (success bool) { - success = true - for i := 0; i < len(values); i++ { if !values[i] { t.Fatalf("Expected 'true' but got 'false'\n") - success = false + return false } } - return success + return true } func IsEq[T comparable](t *testing.T, expected, val T) bool {