@@ -1,7 +1,6 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -16,7 +15,12 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
-    private readonly List<IBatch> _batchQueue = [];
+    private readonly List<IBatch> _batchQueue = [];
+    private int _batchQueueHead;
+    private int _batchedTokenCount;
+    private bool _batchedTokenCountDirty = true;
+    // Skip compacting the queue until this many processed batches accumulate at the front.
+    private const int CleanupThreshold = 16;
 
     /// <summary>
     /// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +46,27 @@ public sealed class BatchedExecutor
     /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
-    public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
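+    // Cache the sum and recompute it lazily; the old LINQ Sum walked the whole queue on every read.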
+    public int BatchedTokenCount
+    {
+        get
+        {
+            if (_batchedTokenCountDirty)
+            {
+                var total = 0;
+                for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
+                    total += _batchQueue[i].ItemCount;
+                _batchedTokenCount = total;
+                _batchedTokenCountDirty = false;
+            }
+
+            return _batchedTokenCount;
+        }
+    }
 
     /// <summary>
     /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchQueueCount => _batchQueue.Count;
+    public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;
 
     /// <summary>
     /// Check if this executor has been disposed.
@@ -147,12 +166,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
         // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
         if (status != DecodeResult.Ok)
         {
-            _batchQueue.Insert(0, next);
+            RequeueFront(next);
             return status;
         }
 
         // Everything was ok, advance the epoch
         Epoch++;
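+        // Drop already-consumed batches from the front of the queue once enough have built up.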
+        CleanupQueue();
 
         return status;
     }
@@ -166,13 +186,45 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
 
         IBatch? GetNextBatch()
         {
-            if (_batchQueue.Count == 0)
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
                 return null;
-
-            var nextBatch = _batchQueue[0];
-            _batchQueue.RemoveAt(0);
+            }
+
+            var nextBatch = _batchQueue[_batchQueueHead];
+            _batchQueueHead++;
+            _batchedTokenCountDirty = true;
             return nextBatch;
         }
+
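+        // Put a failed batch back at the front; every slot before the head holds an already-consumed batch.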
+        void RequeueFront(IBatch batch)
+        {
+            Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
+            _batchQueue[--_batchQueueHead] = batch;
+            _batchedTokenCountDirty = true;
+        }
+
+        // Remove batches that have already been consumed so the head index does not grow without bound.
+        void CleanupQueue()
+        {
+            if (_batchQueueHead == 0)
+                return;
+
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
+                return;
+            }
+
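+            // Only pay for RemoveRange once consumed entries pass the threshold and make up more than half the list.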
+            if (_batchQueueHead > CleanupThreshold && _batchQueueHead > _batchQueue.Count / 2)
+            {
+                _batchQueue.RemoveRange(0, _batchQueueHead);
+                _batchQueueHead = 0;
+            }
+        }
     }
 
     /// <inheritdoc />
@@ -202,7 +254,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
         // Find a batch with space for at least minCapacity tokens
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not TokenBatch { Batch: var batch })
@@ -213,13 +265,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;
 
             if (batch.TokenCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
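+                // i indexes the full list, so subtract the head to get this batch's position among pending batches.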
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }
 
         // Add a new batch to the end of the queue
         var end = new LLamaBatch();
         _batchQueue.Add(new TokenBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }
 
     /// <summary>
@@ -234,7 +290,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
         // Find a batch with space for at least minCapacity embeddings
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not EmbeddingBatch { Batch: var batch })
@@ -245,13 +301,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;
 
             if (batch.EmbeddingsCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }
 
         // Add a new batch to the end of the queue
         var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
         _batchQueue.Add(new EmbeddingBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }
 
     #region batches
@@ -286,4 +346,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
         }
     }
     #endregion
-}
+}
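
The queue change above is a general pattern: replace List.RemoveAt(0), which shifts every remaining element, with a head index that advances past consumed entries and compacts in bulk only occasionally. A minimal standalone sketch of that pattern, using illustrative names rather than code from this commit:

    using System.Collections.Generic;

    // Sketch: dequeue by advancing an index; compact consumed slots in bulk.
    class IndexedQueue<T>
    {
        private readonly List<T> _items = new();
        private int _head;
        private const int CompactThreshold = 16; // same role as CleanupThreshold above

        public int Count => _items.Count - _head;

        public void Enqueue(T item) => _items.Add(item);

        public bool TryDequeue(out T? item)
        {
            if (_head >= _items.Count)
            {
                _items.Clear();
                _head = 0;
                item = default;
                return false;
            }

            item = _items[_head++];

            // Amortise the O(n) RemoveRange: compact only when consumed slots
            // pass the threshold and make up more than half the list.
            if (_head > CompactThreshold && _head > _items.Count / 2)
            {
                _items.RemoveRange(0, _head);
                _head = 0;
            }
            return true;
        }
    }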