// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
3-
43using System ;
54using System . Collections . Generic ;
6- using System . Linq ;
75
86namespace TensorStack . TextGeneration . Processing
97{
/// <summary>
/// Equality comparer for <see cref="Sequence"/> that skips special tokens and
/// compares at most a configurable number of the remaining tokens, in order.
/// </summary>
public sealed class SequenceComparer : IEqualityComparer<Sequence>
{
    private readonly HashSet<long> _specialTokens;
    private int _compareLength;

    /// <summary>
    /// Initializes a new instance of the <see cref="SequenceComparer"/> class.
    /// </summary>
    /// <param name="specialTokens">Special tokens; only the keys are used, and tokens matching a key are skipped by <see cref="Equals(Sequence, Sequence)"/> and <see cref="GetHashCode(Sequence)"/>.</param>
    /// <param name="compareLength">Maximum number of non-special tokens to compare. Values below 1 are clamped to 1.</param>
    public SequenceComparer(IReadOnlyDictionary<long, string> specialTokens, int compareLength = int.MaxValue)
    {
        _specialTokens = [.. specialTokens.Keys];
        _compareLength = Math.Max(1, compareLength);
    }


    /// <summary>
    /// Determines whether two sequences have identical leading non-special tokens.
    /// </summary>
    /// <param name="x">The first sequence.</param>
    /// <param name="y">The second sequence.</param>
    /// <returns>
    /// <c>true</c> if the first <c>compareLength</c> non-special tokens of both sequences are
    /// the same values in the same order (and the same number of them); otherwise <c>false</c>.
    /// Returns <c>false</c> when either argument is <c>null</c>, including when both are.
    /// </returns>
    public bool Equals(Sequence x, Sequence y)
    {
        if (x is null || y is null)
        {
            return false;
        }

        // Same instance: the token walk below would trivially succeed.
        if (ReferenceEquals(x, y))
        {
            return true;
        }

        var xt = x.Tokens;
        var yt = y.Tokens;
        int xi = 0, yi = 0;

        for (int compared = 0; compared < _compareLength; compared++)
        {
            // Advance each cursor to its next non-special token.
            while (xi < xt.Count && _specialTokens.Contains(xt[xi]))
            {
                xi++;
            }

            while (yi < yt.Count && _specialTokens.Contains(yt[yi]))
            {
                yi++;
            }

            bool xExhausted = xi >= xt.Count;
            bool yExhausted = yi >= yt.Count;
            if (xExhausted || yExhausted)
            {
                // Equal only when BOTH ran out together; if just one did, the other
                // still has a token inside the compare window, so the lengths differ.
                // This keeps Equals consistent with GetHashCode, which hashes every
                // token in that window.
                return xExhausted && yExhausted;
            }

            if (xt[xi] != yt[yi])
            {
                return false;
            }

            xi++;
            yi++;
        }

        // The compare window was filled without finding a mismatch.
        return true;
    }


    /// <summary>
    /// Returns a hash code built from the leading non-special tokens of the sequence.
    /// </summary>
    /// <param name="obj">The sequence to hash.</param>
    /// <returns>A hash over the first <c>compareLength</c> non-special tokens, in order.</returns>
    public int GetHashCode(Sequence obj)
    {
        unchecked
        {
            var hash = 17;
            var count = 0;
            var tokens = obj.Tokens;
            for (int i = 0; i < tokens.Count && count < _compareLength; i++)
            {
                var t = tokens[i];
                if (_specialTokens.Contains(t))
                {
                    continue;
                }

                hash = hash * 23 + t.GetHashCode();
                count++;
            }
            return hash;
        }
    }


    /// <summary>
    /// Sets the maximum number of non-special tokens to compare.
    /// </summary>
    /// <remarks>
    /// Both equality and hash codes depend on this value, so changing it while this
    /// comparer backs a populated hash-based collection changes the hash codes the
    /// comparer produces for items already stored.
    /// </remarks>
    /// <param name="length">The new compare length. Values below 1 are clamped to 1.</param>
    public void SetLength(int length)
    {
        _compareLength = Math.Max(1, length);
    }
}
8495}
0 commit comments