11using System ;
22using System . Collections . Generic ;
33using System . Diagnostics ;
4- using System . Linq ;
54using System . Threading ;
65using System . Threading . Tasks ;
76using LLama . Abstractions ;
@@ -16,7 +15,10 @@ public sealed class BatchedExecutor
1615 : IDisposable
1716{
1817 private int _nextSequenceId ;
19- private readonly List < IBatch > _batchQueue = [ ] ;
18+ private readonly List < IBatch > _batchQueue = [ ] ;
19+ private int _batchQueueHead ;
20+ private int _batchedTokenCount ;
21+ private bool _batchedTokenCountDirty = true ;
2022
2123 /// <summary>
2224 /// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +44,27 @@ public sealed class BatchedExecutor
4244 /// <summary>
4345 /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
4446 /// </summary>
public int BatchedTokenCount
{
    get
    {
        // Only re-sum when the queue has changed since the last read; the
        // result is cached in _batchedTokenCount until the dirty flag is
        // raised again by an enqueue/dequeue.
        if (_batchedTokenCountDirty)
        {
            var sum = 0;
            var index = _batchQueueHead;
            while (index < _batchQueue.Count)
            {
                sum += _batchQueue[index].ItemCount;
                index++;
            }

            _batchedTokenCount = sum;
            _batchedTokenCountDirty = false;
        }

        return _batchedTokenCount;
    }
}
4663
4764 /// <summary>
4865 /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
4966 /// </summary>
public int BatchQueueCount
{
    // Entries before _batchQueueHead have already been dequeued (the backing
    // list is compacted lazily), so only the live tail counts.
    get { return _batchQueue.Count - _batchQueueHead; }
}
5168
5269 /// <summary>
5370 /// Check if this executor has been disposed.
@@ -147,12 +164,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
147164 // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
148165 if ( status != DecodeResult . Ok )
149166 {
150- _batchQueue . Insert ( 0 , next ) ;
167+ RequeueFront ( next ) ;
151168 return status ;
152169 }
153170
154171 // Everything was ok, advance the epoch
155172 Epoch ++ ;
173+ CleanupQueue ( ) ;
156174
157175 return status ;
158176 }
@@ -166,13 +184,44 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
166184
167185 IBatch ? GetNextBatch ( )
168186 {
169- if ( _batchQueue . Count == 0 )
187+ if ( _batchQueueHead >= _batchQueue . Count )
188+ {
189+ _batchQueue . Clear ( ) ;
190+ _batchQueueHead = 0 ;
170191 return null ;
171-
172- var nextBatch = _batchQueue [ 0 ] ;
173- _batchQueue . RemoveAt ( 0 ) ;
192+ }
193+
194+ var nextBatch = _batchQueue [ _batchQueueHead ] ;
195+ _batchQueueHead ++ ;
196+ _batchedTokenCountDirty = true ;
174197 return nextBatch ;
175198 }
199+
// Put a batch back at the front of the queue (e.g. after a failed decode),
// reusing the slot it was just dequeued from by GetNextBatch.
void RequeueFront(IBatch batch)
{
    // Fail fast in ALL build configurations: Debug.Assert alone is stripped
    // from release builds, where a zero head would otherwise surface as a
    // confusing ArgumentOutOfRangeException from the list indexer at -1.
    if (_batchQueueHead == 0)
        throw new InvalidOperationException("Cannot requeue batch when queue head is at zero.");

    Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");

    _batchQueueHead--;
    _batchQueue[_batchQueueHead] = batch;
    _batchedTokenCountDirty = true;
}
206+
// Reclaim the consumed slots at the front of the queue after a successful
// inference step. Compaction is deferred until enough dead entries have
// accumulated, so the common path does no list work at all.
void CleanupQueue()
{
    // Nothing has been consumed since the last cleanup.
    if (_batchQueueHead == 0)
        return;

    if (_batchQueueHead >= _batchQueue.Count)
    {
        // Everything has been consumed - drop the whole list.
        _batchQueue.Clear();
        _batchQueueHead = 0;
    }
    else if (_batchQueueHead > 16 && _batchQueueHead > _batchQueue.Count / 2)
    {
        // More than half the list (and more than 16 entries) is dead weight:
        // shift the live tail down to index zero and reset the head.
        _batchQueue.RemoveRange(0, _batchQueueHead);
        _batchQueueHead = 0;
    }
}
176225 }
177226
178227 /// <inheritdoc />
@@ -202,7 +251,7 @@ internal LLamaSeqId GetNextSequenceId()
202251 throw new ArgumentOutOfRangeException ( nameof ( minCapacity ) , $ "Request batch capacity must be less than or equal to BatchSize ({ Context . BatchSize } )") ;
203252
204253 // Find a batch with space for at least minCapacity tokens
205- for ( var i = 0 ; i < _batchQueue . Count ; i ++ )
254+ for ( var i = _batchQueueHead ; i < _batchQueue . Count ; i ++ )
206255 {
207256 var item = _batchQueue [ i ] ;
208257 if ( item is not TokenBatch { Batch : var batch } )
@@ -213,13 +262,17 @@ internal LLamaSeqId GetNextSequenceId()
213262 continue ;
214263
215264 if ( batch . TokenCount < Context . BatchSize )
216- return ( batch , Epoch + ( uint ) ( i + 1 ) * 2 ) ;
265+ {
266+ _batchedTokenCountDirty = true ;
267+ return ( batch , Epoch + ( uint ) ( i - _batchQueueHead + 1 ) * 2 ) ;
268+ }
217269 }
218270
219271 // Add a new batch to the end of the queue
220272 var end = new LLamaBatch ( ) ;
221273 _batchQueue . Add ( new TokenBatch ( end ) ) ;
222- return ( end , Epoch + ( uint ) _batchQueue . Count * 2 ) ;
274+ _batchedTokenCountDirty = true ;
275+ return ( end , Epoch + ( uint ) ( _batchQueue . Count - _batchQueueHead ) * 2 ) ;
223276 }
224277
225278 /// <summary>
@@ -234,7 +287,7 @@ internal LLamaSeqId GetNextSequenceId()
234287 throw new ArgumentOutOfRangeException ( nameof ( minCapacity ) , $ "Request batch capacity must be less than or equal to BatchSize ({ Context . BatchSize } )") ;
235288
236289 // Find a batch with space for at least minCapacity embeddings
237- for ( var i = 0 ; i < _batchQueue . Count ; i ++ )
290+ for ( var i = _batchQueueHead ; i < _batchQueue . Count ; i ++ )
238291 {
239292 var item = _batchQueue [ i ] ;
240293 if ( item is not EmbeddingBatch { Batch : var batch } )
@@ -245,13 +298,17 @@ internal LLamaSeqId GetNextSequenceId()
245298 continue ;
246299
247300 if ( batch . EmbeddingsCount < Context . BatchSize )
248- return ( batch , Epoch + ( uint ) ( i + 1 ) * 2 ) ;
301+ {
302+ _batchedTokenCountDirty = true ;
303+ return ( batch , Epoch + ( uint ) ( i - _batchQueueHead + 1 ) * 2 ) ;
304+ }
249305 }
250306
251307 // Add a new batch to the end of the queue
252308 var end = new LLamaBatchEmbeddings ( Context . EmbeddingSize ) ;
253309 _batchQueue . Add ( new EmbeddingBatch ( end ) ) ;
254- return ( end , Epoch + ( uint ) _batchQueue . Count * 2 ) ;
310+ _batchedTokenCountDirty = true ;
311+ return ( end , Epoch + ( uint ) ( _batchQueue . Count - _batchQueueHead ) * 2 ) ;
255312 }
256313
257314 #region batches
@@ -286,4 +343,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
286343 }
287344 }
288345 #endregion
289- }
346+ }
0 commit comments