Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace BitFaster.Caching.Benchmarks.Lfu
{
#if Windows
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
#endif
[SimpleJob(RuntimeMoniker.Net60)]
[SimpleJob(RuntimeMoniker.Net80)]
[SimpleJob(RuntimeMoniker.Net90)]
Expand Down
3 changes: 3 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

namespace BitFaster.Caching.Benchmarks.Lfu
{
#if Windows
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
#endif
[SimpleJob(RuntimeMoniker.Net60)]
[SimpleJob(RuntimeMoniker.Net80)]
[SimpleJob(RuntimeMoniker.Net90)]
Expand Down
62 changes: 15 additions & 47 deletions BitFaster.Caching/Lfu/CmSketchCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,39 +255,26 @@ private void Reset()
}

#if !NETSTANDARD2_0
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe int EstimateFrequencyAvx(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Vector128.Create(counterHash);
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

var index = Avx2.ShiftRightLogical(h, 1);
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
Vector256<ulong> indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
index = Avx2.ShiftLeftLogical(index, 2);

// convert index from int to long via permute
Vector256<long> indexLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
Vector256<int> permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64();
tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64());
tableVector = Avx2.And(tableVector, Vector256.Create(0xfL));

Vector256<int> permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7);
Vector128<ushort> count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask)
Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
.GetLower()
.AsUInt16();

Expand All @@ -302,52 +289,33 @@ private unsafe int EstimateFrequencyAvx(T value)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe void IncrementAvx(T value)
{
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Vector128.Create(counterHash);
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));

Vector128<int> index = Avx2.ShiftRightLogical(h, 1);
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
Vector256<ulong> offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong);

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);

// j == index
index = Avx2.ShiftLeftLogical(index, 2);
Vector256<long> offsetLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();

Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64();

// mask = (0xfL << offset)
Vector256<long> fifteen = Vector256.Create(0xfL);
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64());

// (table[i] & mask) != mask)
// Note masked is 'equal' - therefore use AndNot below
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask);

// 1L << offset
Vector256<long> inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64());
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask);

// Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters)
inc = Avx2.AndNot(masked, inc);
Vector256<long> inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong));

Vector256<byte> result = Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero);
bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));

tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);
Expand Down
Loading