Skip to content

Commit a780ad9

Browse files
committed
Faster SequenceComparer
1 parent 37e7ddb commit a780ad9

File tree

2 files changed: +35 -24 lines changed

TensorStack.TextGeneration/Common/GenerateOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ public record GenerateOptions : IRunOptions
1818
public float Temperature { get; set; } = 1.0f;
1919
public float LengthPenalty { get; set; } = 1.0f;
2020
public EarlyStopping EarlyStopping { get; set; }
21-
public int DiversityLength { get; set; } = 5;
21+
public int DiversityLength { get; set; } = 20;
2222
}
2323
}
Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
43
using System;
54
using System.Collections.Generic;
6-
using System.Linq;
75

86
namespace TensorStack.TextGeneration.Processing
97
{
10-
public class SequenceComparer : IEqualityComparer<Sequence>
8+
public sealed class SequenceComparer : IEqualityComparer<Sequence>
119
{
1210
private readonly HashSet<long> _specialTokens;
1311
private int _compareLength;
@@ -19,8 +17,8 @@ public class SequenceComparer : IEqualityComparer<Sequence>
1917
/// <param name="compareLength">Length of the compare.</param>
2018
public SequenceComparer(IReadOnlyDictionary<long, string> specialTokens, int compareLength = int.MaxValue)
2119
{
22-
SetLength(compareLength);
2320
_specialTokens = [.. specialTokens.Keys];
21+
_compareLength = Math.Max(1, compareLength);
2422
}
2523

2624

@@ -35,9 +33,26 @@ public bool Equals(Sequence x, Sequence y)
3533
if (x == null || y == null)
3634
return false;
3735

38-
var normX = NormalizeTokens(x.Tokens);
39-
var normY = NormalizeTokens(y.Tokens);
40-
return normX.SequenceEqual(normY);
36+
int cx = 0, cy = 0;
37+
var xt = x.Tokens;
38+
var yt = y.Tokens;
39+
int xi = 0, yi = 0;
40+
while (xi < xt.Count && yi < yt.Count && cx < _compareLength && cy < _compareLength)
41+
{
42+
while (xi < xt.Count && _specialTokens.Contains(xt[xi])) xi++;
43+
while (yi < yt.Count && _specialTokens.Contains(yt[yi])) yi++;
44+
45+
if (xi >= xt.Count || yi >= yt.Count)
46+
break;
47+
48+
if (xt[xi] != yt[yi])
49+
return false;
50+
51+
xi++; yi++;
52+
cx++; cy++;
53+
}
54+
55+
return cx == cy;
4156
}
4257

4358

@@ -50,9 +65,18 @@ public int GetHashCode(Sequence obj)
5065
{
5166
unchecked
5267
{
53-
int hash = 17;
54-
foreach (var val in NormalizeTokens(obj.Tokens))
55-
hash = hash * 23 + val.GetHashCode();
68+
var hash = 17;
69+
var count = 0;
70+
var tokens = obj.Tokens;
71+
for (int i = 0; i < tokens.Count && count < _compareLength; i++)
72+
{
73+
var t = tokens[i];
74+
if (_specialTokens.Contains(t))
75+
continue;
76+
77+
hash = hash * 23 + t.GetHashCode();
78+
count++;
79+
}
5680
return hash;
5781
}
5882
}
@@ -67,18 +91,5 @@ public void SetLength(int length)
6791
_compareLength = Math.Max(1, length);
6892
}
6993

70-
71-
/// <summary>
72-
/// Normalizes the tokens.
73-
/// </summary>
74-
/// <param name="tokens">The tokens.</param>
75-
/// <returns>IEnumerable&lt;System.Int64&gt;.</returns>
76-
private IEnumerable<long> NormalizeTokens(IReadOnlyList<long> tokens)
77-
{
78-
foreach (var t in tokens.Except(_specialTokens).Take(_compareLength))
79-
{
80-
yield return t;
81-
}
82-
}
8394
}
8495
}

0 commit comments

Comments
 (0)