Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace BitFaster.Caching.UnitTests.Lfu
{
// Test with AVX2 or ARM64 intrinsics when the hardware supports them (selected via DetectIsa).
public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
{
}
Expand Down
28 changes: 25 additions & 3 deletions BitFaster.Caching/Intrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Runtime.Intrinsics.X86;
#endif

// Use the same NET6_0_OR_GREATER guard as the code that references AdvSimd;
// a bare NET6_0 check would drop this using directive on net7.0 and later targets.
#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

namespace BitFaster.Caching
{
/// <summary>
Expand All @@ -12,7 +16,14 @@ public interface IsaProbe
/// <summary>
/// Gets a value indicating whether AVX2 is supported.
/// </summary>
bool IsAvx2Supported { get; }

#if NET6_0_OR_GREATER
/// <summary>
/// Gets a value indicating whether Arm64 is supported.
/// </summary>
/// <remarks>
/// Declared with a default implementation that returns false, so an implementer
/// that does not define this member reports no Arm64 support.
/// </remarks>
bool IsArm64Supported { get => false; }
#endif
}

/// <summary>
Expand All @@ -25,7 +36,15 @@ public interface IsaProbe
public bool IsAvx2Supported => false;
#else
/// <inheritdoc/>
public bool IsAvx2Supported => Avx2.IsSupported;
#endif

#if NET6_0_OR_GREATER
// Reports whatever the runtime says about AdvSimd.Arm64 support.
/// <inheritdoc/>
public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
#else
// Without NET6_0_OR_GREATER the Arm64 code path is never selected.
/// <inheritdoc/>
public bool IsArm64Supported => false;
#endif
}

Expand All @@ -35,6 +54,9 @@ public interface IsaProbe
public readonly struct DisableHardwareIntrinsics : IsaProbe
{
    /// <inheritdoc/>
    public bool IsAvx2Supported => false;

    /// <inheritdoc/>
    public bool IsArm64Supported => false;
}
}
153 changes: 153 additions & 0 deletions BitFaster.Caching/Lfu/CmSketchCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
using System.Runtime.Intrinsics.X86;
#endif

#if NET6_0_OR_GREATER
using System.Runtime.Intrinsics.Arm;
#endif

namespace BitFaster.Caching.Lfu
{
/// <summary>
Expand Down Expand Up @@ -76,6 +80,12 @@ public int EstimateFrequency(T value)
{
return EstimateFrequencyAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
return EstimateFrequencyArm(value);
}
#endif
else
{
return EstimateFrequencyStd(value);
Expand All @@ -99,6 +109,12 @@ public void Increment(T value)
{
IncrementAvx(value);
}
#if NET6_0_OR_GREATER
else if (isa.IsArm64Supported)
{
IncrementArm(value);
}
#endif
else
{
IncrementStd(value);
Expand Down Expand Up @@ -329,5 +345,142 @@ private unsafe void IncrementAvx(T value)
}
}
#endif

#if NET6_0_OR_GREATER
/// <summary>
/// Increments the four 4-bit counters selected by the hash of <paramref name="value"/>
/// using Arm64 AdvSimd intrinsics. Counters that are already saturated (all four bits set)
/// are left unchanged.
/// </summary>
/// <param name="value">The value whose counters are incremented.</param>
private unsafe void IncrementArm(T value)
{
    int blockHash = Spread(comparer.GetHashCode(value));
    int counterHash = Rehash(blockHash);
    int block = (blockHash & blockMask) << 3;

    // Negative per-lane counts make ShiftArithmetic shift right: lane i holds counterHash >> (i * 8).
    Vector128<int> h = Vector128.Create(counterHash);
    h = AdvSimd.ShiftArithmetic(h, Vector128.Create(0, -8, -16, -24));

    Vector128<int> index = AdvSimd.ShiftRightLogical(h, 1);
    index = AdvSimd.And(index, Vector128.Create(15)); // j - counter index
    Vector128<int> offset = AdvSimd.And(h, Vector128.Create(1));
    Vector128<int> blockOffset = AdvSimd.Add(Vector128.Create(block), offset); // i - table index
    blockOffset = AdvSimd.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)

    fixed (long* tablePtr = table)
    {
        int t0 = AdvSimd.Extract(blockOffset, 0);
        int t1 = AdvSimd.Extract(blockOffset, 1);
        int t2 = AdvSimd.Extract(blockOffset, 2);
        int t3 = AdvSimd.Extract(blockOffset, 3);

        // TODO: VectorTableLookup
        Vector128<long> tableVectorA = Vector128.Create(
            AdvSimd.LoadVector64(tablePtr + t0),
            AdvSimd.LoadVector64(tablePtr + t1));
        Vector128<long> tableVectorB = Vector128.Create(
            AdvSimd.LoadVector64(tablePtr + t2),
            AdvSimd.LoadVector64(tablePtr + t3));

        // j == index; multiply by 4 to turn the counter index into a bit offset.
        index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

        // Spread the four 32-bit bit offsets into the low halves of 64-bit lanes.
        Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0);
        longOffA = AdvSimd.Arm64.InsertSelectedScalar(longOffA, 2, index, 1);

        Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2);
        longOffB = AdvSimd.Arm64.InsertSelectedScalar(longOffB, 2, index, 3);

        Vector128<long> fifteen = Vector128.Create(0xfL);
        Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
        Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());

        // All-ones lanes mark counters that are NOT saturated and may be incremented.
        Vector128<long> notSaturatedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
        Vector128<long> notSaturatedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));

        Vector128<long> one = Vector128.Create(1L);
        Vector128<long> incA = AdvSimd.ShiftArithmetic(one, longOffA.AsInt64());
        Vector128<long> incB = AdvSimd.ShiftArithmetic(one, longOffB.AsInt64());

        // Each half must be gated by its own saturation mask; masking incB with the A mask
        // would let saturated counters 2 and 3 overflow into the neighbouring counter.
        incA = AdvSimd.And(notSaturatedA, incA);
        incB = AdvSimd.And(notSaturatedB, incB);

        tablePtr[t0] += AdvSimd.Extract(incA, 0);
        tablePtr[t1] += AdvSimd.Extract(incA, 1);
        tablePtr[t2] += AdvSimd.Extract(incB, 0);
        tablePtr[t3] += AdvSimd.Extract(incB, 1);

        // Non-zero when at least one counter was actually incremented.
        var maxA = AdvSimd.Arm64.MaxAcross(incA.AsInt32());
        var maxB = AdvSimd.Arm64.MaxAcross(incB.AsInt32());
        maxA = AdvSimd.Arm64.InsertSelectedScalar(maxA, 1, maxB, 0);
        var max = AdvSimd.Arm64.MaxAcross(maxA.AsInt16());

        if (max.ToScalar() != 0 && (++size == sampleSize))
        {
            Reset();
        }
    }
}

/// <summary>
/// Estimates the frequency of <paramref name="value"/> using Arm64 AdvSimd intrinsics:
/// reads the four 4-bit counters selected by the value's hash and returns the smallest.
/// </summary>
/// <param name="value">The value to estimate.</param>
/// <returns>The minimum of the four selected counters, in the range 0 to 15.</returns>
private unsafe int EstimateFrequencyArm(T value)
{
    int blockHash = Spread(comparer.GetHashCode(value));
    int counterHash = Rehash(blockHash);
    int block = (blockHash & blockMask) << 3;

    // Negative per-lane counts make ShiftArithmetic shift right: lane i holds counterHash >> (i * 8).
    Vector128<int> h = Vector128.Create(counterHash);
    h = AdvSimd.ShiftArithmetic(h, Vector128.Create(0, -8, -16, -24));

    Vector128<int> index = AdvSimd.ShiftRightLogical(h, 1);

    index = AdvSimd.And(index, Vector128.Create(0xf)); // j - counter index
    Vector128<int> offset = AdvSimd.And(h, Vector128.Create(1));
    Vector128<int> blockOffset = AdvSimd.Add(Vector128.Create(block), offset); // i - table index
    blockOffset = AdvSimd.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)

    fixed (long* tablePtr = table)
    {
        // TODO: VectorTableLookup
        Vector128<long> tableVectorA = Vector128.Create(
            tablePtr[AdvSimd.Extract(blockOffset, 0)],
            tablePtr[AdvSimd.Extract(blockOffset, 1)]);
        Vector128<long> tableVectorB = Vector128.Create(
            tablePtr[AdvSimd.Extract(blockOffset, 2)],
            tablePtr[AdvSimd.Extract(blockOffset, 3)]);

        // j == index; multiply by 4 to turn the counter index into a bit offset.
        index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);

        // Spread the four 32-bit bit offsets into the low halves of 64-bit lanes.
        Vector128<int> indexA = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0);
        indexA = AdvSimd.Arm64.InsertSelectedScalar(indexA, 2, index, 1);

        Vector128<int> indexB = AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2);
        indexB = AdvSimd.Arm64.InsertSelectedScalar(indexB, 2, index, 3);

        // Negated counts turn the variable shift into a right shift.
        indexA = AdvSimd.Negate(indexA);
        indexB = AdvSimd.Negate(indexB);

        Vector128<long> a = AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64());
        Vector128<long> b = AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64());

        var fifteen = Vector128.Create(0xfL);
        a = AdvSimd.And(a, fifteen);
        b = AdvSimd.And(b, fifteen);

        // Each 64-bit lane now holds a count in 0..15, so its upper 32 bits are always zero.
        // Pack only the even 32-bit lanes (the low halves) of a and b before reducing;
        // taking the minimum across every 32-bit lane would always yield 0.
        Vector128<int> counts = AdvSimd.Arm64.UnzipEven(a.AsInt32(), b.AsInt32());

        return AdvSimd.Arm64.MinAcross(counts).ToScalar();
    }
}
#endif
}
}