Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

using System.Collections.Generic;
using Benchly;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Jobs;
using BitFaster.Caching.Lfu;
Expand All @@ -9,6 +10,7 @@ namespace BitFaster.Caching.Benchmarks.Lfu
[SimpleJob(RuntimeMoniker.Net60)]
[MemoryDiagnoser(displayGenColumns: false)]
[HideColumns("Job", "Median", "RatioSD", "Alloc Ratio")]
[ColumnChart(Title = "Sketch Increment ({JOB})")]
public class SketchIncrement
{
const int iterations = 1_048_576;
Expand Down
10 changes: 10 additions & 0 deletions BitFaster.Caching.UnitTests/Intrinsics.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#if NETCOREAPP3_1_OR_GREATER
using System.Runtime.Intrinsics.X86;
#endif
#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

using Xunit;

namespace BitFaster.Caching.UnitTests
Expand All @@ -10,8 +14,14 @@ public static class Intrinsics
public static void SkipAvxIfNotSupported<I>()
{
#if NETCOREAPP3_1_OR_GREATER
#if NET6_0_OR_GREATER
// when we are trying to test Avx2/Arm64, skip the test if it's not supported
Skip.If(typeof(I) == typeof(DetectIsa) && !(Avx2.IsSupported || AdvSimd.Arm64.IsSupported));
#else
// when we are trying to test Avx2, skip the test if it's not supported
Skip.If(typeof(I) == typeof(DetectIsa) && !Avx2.IsSupported);
#endif

#else
Skip.If(true);
#endif
Expand Down
15 changes: 12 additions & 3 deletions BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

namespace BitFaster.Caching.UnitTests.Lfu
{
// Test with AVX2 if it is supported
public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
// Runs the shared sketch tests with the DetectIsa probe, i.e. with the AVX2/ARM64
// intrinsic code paths enabled whenever the host CPU supports them.
public class CMSketchIntrinsicsTests : CmSketchTestBase<DetectIsa>
{
}

// Test with AVX2 disabled
// Runs the shared sketch tests with the DisableHardwareIntrinsics probe, i.e. with
// the AVX2/ARM64 intrinsic code paths disabled.
public class CmSketchTests : CmSketchTestBase<DisableHardwareIntrinsics>
{
}
Expand All @@ -29,14 +29,23 @@ public CmSketchTestBase()
public void Repro()
{
    // The same count is used to size both sketches and to bound both loops.
    const int count = 1_048_576;
    var intComparer = EqualityComparer<int>.Default;

    sketch = new CmSketchCore<int, I>(count, intComparer);
    var reference = new CmSketchCore<int, DisableHardwareIntrinsics>(count, intComparer);

    // Feed every multiple of three below count to both sketches, in lock step.
    for (int key = 0; key < count; key += 3)
    {
        sketch.Increment(key);
        reference.Increment(key);
    }

    reference.Size.Should().Be(sketch.Size);

    // The sketch under test must report the same estimate as the
    // intrinsics-disabled reference for every key in the range.
    for (int key = 0; key < count; key++)
    {
        var actual = sketch.EstimateFrequency(key);
        var expected = reference.EstimateFrequency(key);

        actual.Should().Be(expected);
    }
}


Expand Down
28 changes: 25 additions & 3 deletions BitFaster.Caching/Intrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Runtime.Intrinsics.X86;
#endif

// AdvSimd is referenced further down under NET6_0_OR_GREATER, so the import must use
// the same condition; NET6_0 alone is not defined when targeting net7.0 or later.
#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

namespace BitFaster.Caching
{
/// <summary>
Expand All @@ -12,7 +16,14 @@ public interface IsaProbe
/// <summary>
/// Gets a value indicating whether AVX2 is supported.
/// </summary>
bool IsAvx2Supported { get; }
bool IsAvx2Supported { get; }

#if NET6_0_OR_GREATER
/// <summary>
/// Gets a value indicating whether Arm64 is supported.
/// </summary>
bool IsArm64Supported { get => false; }
#endif
}

/// <summary>
Expand All @@ -25,7 +36,15 @@ public interface IsaProbe
public bool IsAvx2Supported => false;
#else
/// <inheritdoc/>
public bool IsAvx2Supported => Avx2.IsSupported;
public bool IsAvx2Supported => Avx2.IsSupported;
#endif

#if NET6_0_OR_GREATER
/// <inheritdoc/>
public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
#else
/// <inheritdoc/>
public bool IsArm64Supported => false;
#endif
}

Expand All @@ -35,6 +54,9 @@ public interface IsaProbe
public readonly struct DisableHardwareIntrinsics : IsaProbe
{
    // Both probes unconditionally report "not supported"; IsAvx2Supported is declared
    // exactly once (a second identical declaration would not compile).

    /// <inheritdoc/>
    public bool IsAvx2Supported => false;

    /// <inheritdoc/>
    public bool IsArm64Supported => false;
}
}
156 changes: 156 additions & 0 deletions BitFaster.Caching/Lfu/CmSketchCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
using System.Runtime.Intrinsics.X86;
#endif

#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

namespace BitFaster.Caching.Lfu
{
/// <summary>
Expand Down Expand Up @@ -76,6 +80,12 @@ public int EstimateFrequency(T value)
{
return EstimateFrequencyAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
return EstimateFrequencyArm(value);
}
#endif
else
{
return EstimateFrequencyStd(value);
Expand All @@ -99,6 +109,12 @@ public void Increment(T value)
{
IncrementAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
IncrementArm(value);
}
#endif
else
{
IncrementStd(value);
Expand Down Expand Up @@ -329,5 +345,145 @@ private unsafe void IncrementAvx(T value)
}
}
#endif

#if NET6_0_OR_GREATER
/// <summary>
/// ARM (AdvSimd) implementation of the increment operation: adds one to a 4-bit counter
/// in each of four table slots selected from the hash of <paramref name="value"/>,
/// leaving counters that already hold 15 unchanged.
/// </summary>
/// <param name="value">The value whose counters are incremented.</param>
/// <remarks>
/// When at least one counter was incremented, size is incremented and Reset() is
/// called once size equals sampleSize.
/// </remarks>
private unsafe void IncrementArm(T value)
{
// block is a multiple of 8; the four slots used below all lie in table[block .. block + 7].
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

// ShiftArithmetic shifts right for negative counts, so the four lanes hold
// counterHash >> 0, >> 8, >> 16 and >> 24 (one byte of the hash per lane).
Vector128<int> h = Vector128.Create(counterHash);
h = AdvSimd.ShiftArithmetic(h, Vector128.Create(0, -8, -16, -24));

// Per lane: bits 1..4 choose one of 16 counters inside a long, bit 0 chooses
// between the two longs of the pair starting at block + 0/2/4/6.
Vector128<int> index = AdvSimd.ShiftRightLogical(h, 1);
index = AdvSimd.And(index, Vector128.Create(15)); // j - counter index
Vector128<int> offset = AdvSimd.And(h, Vector128.Create(1));
Vector128<int> blockOffset = AdvSimd.Add(Vector128.Create(block), offset); // i - table index
blockOffset = AdvSimd.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)

// Pin the table so the selected slots can be read and written through a raw pointer.
fixed (long* tablePtr = table)
{
// Table indices of the four slots, one per lane.
int t0 = AdvSimd.Extract(blockOffset, 0);
int t1 = AdvSimd.Extract(blockOffset, 1);
int t2 = AdvSimd.Extract(blockOffset, 2);
int t3 = AdvSimd.Extract(blockOffset, 3);

// Load the four 64-bit slots and pack them pairwise into two 128-bit vectors.
var ta0 = AdvSimd.LoadVector64(tablePtr + t0);
var ta1 = AdvSimd.LoadVector64(tablePtr + t1);
var ta2 = AdvSimd.LoadVector64(tablePtr + t2);
var ta3 = AdvSimd.LoadVector64(tablePtr + t3);

Vector128<long> tableVectorA = Vector128.Create(ta0, ta1);
Vector128<long> tableVectorB = Vector128.Create(ta2, ta3);

// TODO: VectorTableLookup (the slots could alternatively be gathered with scalar
// reads, i.e. Vector128.Create(tablePtr[t0], tablePtr[t1]) and likewise for t2/t3).

// j == index
// Multiply by 4 to turn the counter index into a bit offset (0..60); the index
// is at most 15, so the saturating shift is not expected to saturate.
index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

// Spread the 32-bit offsets into int lanes 0 and 2 so that, reinterpreted as
// Int64, each 64-bit lane carries the offset for the matching table slot.
Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0);
longOffA = AdvSimd.Arm64.InsertSelectedScalar(longOffA, 2, index, 1);

Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2);
longOffB = AdvSimd.Arm64.InsertSelectedScalar(longOffB, 2, index, 3);

// 0xf shifted left to the counter position: the mask of each selected counter.
Vector128<long> fifteen = Vector128.Create(0xfL);
Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());

// All-ones in lanes whose counter already equals 15 (saturated), zero otherwise.
Vector128<long> maskedA = AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA);
Vector128<long> maskedB = AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB);

// 1 shifted to the counter position: the amount to add to each slot.
var one = Vector128.Create(1L);
Vector128<long> incA = AdvSimd.ShiftArithmetic(one, longOffA.AsInt64());
Vector128<long> incB = AdvSimd.ShiftArithmetic(one, longOffB.AsInt64());

// Invert so saturated lanes become zero, then clear their increment.
maskedA = AdvSimd.Not(maskedA);
maskedB = AdvSimd.Not(maskedB);

incA = AdvSimd.And(maskedA, incA);
incB = AdvSimd.And(maskedB, incB);

// Apply the (possibly zeroed) increments to the four slots.
tablePtr[t0] += AdvSimd.Extract(incA, 0);
tablePtr[t1] += AdvSimd.Extract(incA, 1);
tablePtr[t2] += AdvSimd.Extract(incB, 0);
tablePtr[t3] += AdvSimd.Extract(incB, 1);

// Horizontal max over all increments: non-zero when any counter was incremented.
var maxA = AdvSimd.Arm64.MaxAcross(incA.AsInt32());
var maxB = AdvSimd.Arm64.MaxAcross(incB.AsInt32());
maxA = AdvSimd.Arm64.InsertSelectedScalar(maxA, 1, maxB, 0);
var max = AdvSimd.Arm64.MaxAcross(maxA.AsInt16());

// size is only advanced when something was actually incremented.
if (max.ToScalar() != 0 && (++size == sampleSize))
{
Reset();
}
}
}

/// <summary>
/// ARM (AdvSimd) implementation of the frequency estimate: reads one 4-bit counter from
/// each of four table slots selected from the hash of <paramref name="value"/> and
/// returns the smallest of the four.
/// </summary>
/// <param name="value">The value to estimate.</param>
/// <returns>The minimum of the four selected counters (0..15).</returns>
private unsafe int EstimateFrequencyArm(T value)
{
// block is a multiple of 8; the four slots used below all lie in table[block .. block + 7].
int blockHash = Spread(comparer.GetHashCode(value));
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

// ShiftArithmetic shifts right for negative counts, so the four lanes hold
// counterHash >> 0, >> 8, >> 16 and >> 24 (one byte of the hash per lane).
Vector128<int> h = Vector128.Create(counterHash);
h = AdvSimd.ShiftArithmetic(h, Vector128.Create(0, -8, -16, -24));

// Per lane: bits 1..4 choose one of 16 counters inside a long, bit 0 chooses
// between the two longs of the pair starting at block + 0/2/4/6.
Vector128<int> index = AdvSimd.ShiftRightLogical(h, 1);

index = AdvSimd.And(index, Vector128.Create(0xf)); // j - counter index
Vector128<int> offset = AdvSimd.And(h, Vector128.Create(1));
Vector128<int> blockOffset = AdvSimd.Add(Vector128.Create(block), offset); // i - table index
blockOffset = AdvSimd.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)

// Pin the table so the selected slots can be read through a raw pointer.
fixed (long* tablePtr = table)
{
// TODO: VectorTableLookup
// Gather the four 64-bit slots with scalar reads and pack them pairwise.
Vector128<long> tableVectorA = Vector128.Create(
tablePtr[AdvSimd.Extract(blockOffset, 0)],
tablePtr[AdvSimd.Extract(blockOffset, 1)]);
Vector128<long> tableVectorB = Vector128.Create(
tablePtr[AdvSimd.Extract(blockOffset, 2)],
tablePtr[AdvSimd.Extract(blockOffset, 3)]);

// j == index
// Multiply by 4 to turn the counter index into a bit offset (0..60); the index
// is at most 15, so the saturating shift is not expected to saturate.
index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

// Spread the 32-bit offsets into int lanes 0 and 2 so that, reinterpreted as
// Int64, each 64-bit lane carries the offset for the matching table slot.
Vector128<int> indexA = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0);
indexA = AdvSimd.Arm64.InsertSelectedScalar(indexA, 2, index, 1);

Vector128<int> indexB = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2);
indexB = AdvSimd.Arm64.InsertSelectedScalar(indexB, 2, index, 3);

// Negate the offsets: ShiftArithmetic shifts right when the count is negative.
indexA = AdvSimd.Negate(indexA);
indexB = AdvSimd.Negate(indexB);

// Move each selected counter down to the low 4 bits of its lane.
Vector128<long> a = AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64());
Vector128<long> b = AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64());

// Keep only the 4-bit counter value.
var fifteen = Vector128.Create(0xfL);
a = AdvSimd.And(a, fifteen);
b = AdvSimd.And(b, fifteen);

// TODO: VectorTableLookup
// Pack the low 32 bits of the four 64-bit results into one 4 x int vector.
Vector128<int> x = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, a.AsInt32(), 0);
x = AdvSimd.Arm64.InsertSelectedScalar(x, 1, a.AsInt32(), 2);
x = AdvSimd.Arm64.InsertSelectedScalar(x, 2, b.AsInt32(), 0);
x = AdvSimd.Arm64.InsertSelectedScalar(x, 3, b.AsInt32(), 2);

// Horizontal minimum across the four counters is the estimate.
var minA = AdvSimd.Arm64.MinAcross(x);

return minA.ToScalar();
}
}
#endif
}
}