
Commit bef0aae
Some small optimizations
1 parent de00c15 commit bef0aae
8 files changed: +230 −61 lines
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+using System.Linq;
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Engines;
+using BenchmarkDotNet.Jobs;
+using LLama.Common;
+
+namespace LLama.Benchmark.Collections;
+
+[SimpleJob(RunStrategy.Throughput, RuntimeMoniker.Net80)]
+[MemoryDiagnoser]
+[BenchmarkCategory("Collections", "FixedSizeQueue")]
+public class FixedSizeQueueBenchmark
+{
+    [Params(32, 512, 4096)]
+    public int Capacity { get; set; }
+
+    private int[] _values = Array.Empty<int>();
+
+    [GlobalSetup]
+    public void Setup()
+    {
+        _values = Enumerable.Range(0, Capacity * 4).ToArray();
+    }
+
+    [Benchmark]
+    public int EnqueueWrap()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+        return queue.Count;
+    }
+
+    [Benchmark]
+    public int IterateTailSum()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+
+        var sum = 0;
+        foreach (var value in queue)
+            sum += value;
+        return sum;
+    }
+}
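
Note: this new benchmark measures two hot paths on FixedSizeQueue from LLama.Common. Each run pushes 4× Capacity items, so EnqueueWrap forces the queue to evict its oldest items repeatedly, and IterateTailSum then walks the retained tail. For orientation only, here is a minimal ring-buffer sketch of the drop-oldest behavior being measured; the RingQueue name and every implementation detail are assumptions, not LLamaSharp's actual FixedSizeQueue.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;

// Hypothetical sketch of drop-oldest queue semantics; the real
// FixedSizeQueue<T> in LLama.Common may be implemented differently.
public sealed class RingQueue<T> : IEnumerable<T>
{
    private readonly T[] _items;
    private int _head;   // index of the oldest item
    private int _count;

    public RingQueue(int capacity)
    {
        if (capacity <= 0)
            throw new ArgumentOutOfRangeException(nameof(capacity));
        _items = new T[capacity];
    }

    public int Count => _count;

    public void Enqueue(T item)
    {
        if (_count < _items.Length)
        {
            // Still filling: write just past the newest item.
            _items[(_head + _count) % _items.Length] = item;
            _count++;
        }
        else
        {
            // Full: overwrite the oldest item and advance the head (the "wrap").
            _items[_head] = item;
            _head = (_head + 1) % _items.Length;
        }
    }

    public IEnumerator<T> GetEnumerator()
    {
        // Yield from oldest to newest, wrapping around the backing array.
        for (var i = 0; i < _count; i++)
            yield return _items[(_head + i) % _items.Length];
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```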

LLama/AntipromptProcessor.cs

Lines changed: 8 additions & 6 deletions
@@ -11,7 +11,7 @@ public sealed class AntipromptProcessor
     private int _longestAntiprompt;
     private readonly List<string> _antiprompts = new();

-    private string? _string;
+    private string _buffer = string.Empty;


     /// <summary>
@@ -46,6 +46,8 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
         _longestAntiprompt = 0;
         foreach (var antiprompt in _antiprompts)
             _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
+
+        _buffer = string.Empty;
     }

     /// <summary>
@@ -55,21 +57,21 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
     /// <returns>true if the text buffer ends with any antiprompt</returns>
     public bool Add(string text)
     {
-        _string += text;
+        _buffer += text;

         // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
         // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
         // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
         var maxLength = Math.Max(32, _longestAntiprompt * 4);
         var trimLength = Math.Max(16, _longestAntiprompt * 2);
-        if (_string.Length > maxLength)
-            _string = _string.Substring(_string.Length - trimLength);
+        if (_buffer.Length > maxLength)
+            _buffer = _buffer.Substring(_buffer.Length - trimLength);

         foreach (var antiprompt in _antiprompts)
-            if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
+            if (_buffer.EndsWith(antiprompt, StringComparison.CurrentCulture))
                 return true;

         return false;
     }
 }
-}
+}
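
Note: this change renames the nullable _string field to _buffer, initialized to string.Empty so Add no longer starts from null, and resets the buffer whenever the antiprompt set changes. The trim bounds keep memory constant: with a longest antiprompt of 5 characters, maxLength = max(32, 5 × 4) = 32 and trimLength = max(16, 5 × 2) = 16, so once the buffer exceeds 32 characters only the last 16 are kept, still more than enough trailing context to match any antiprompt. A usage sketch, assuming a parameterless constructor (the constructor is not shown in this diff):

```csharp
var processor = new AntipromptProcessor();   // assumed constructor
processor.SetAntiprompts(new[] { "User:" });

// Streamed tokens often split an antiprompt across chunks; the rolling
// buffer keeps enough trailing characters to still detect it.
Console.WriteLine(processor.Add("Hello!\nUse")); // false
Console.WriteLine(processor.Add("r:"));          // true: buffer now ends with "User:"
```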

LLama/Batched/BatchedExecutor.cs

Lines changed: 73 additions & 16 deletions
@@ -1,7 +1,6 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -16,7 +15,10 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
-    private readonly List<IBatch> _batchQueue = [ ];
+    private readonly List<IBatch> _batchQueue = [];
+    private int _batchQueueHead;
+    private int _batchedTokenCount;
+    private bool _batchedTokenCountDirty = true;

     /// <summary>
     /// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +44,27 @@ public sealed class BatchedExecutor
     /// <summary>
     /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
+    public int BatchedTokenCount
+    {
+        get
+        {
+            if (_batchedTokenCountDirty)
+            {
+                var total = 0;
+                for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
+                    total += _batchQueue[i].ItemCount;
+                _batchedTokenCount = total;
+                _batchedTokenCountDirty = false;
+            }
+
+            return _batchedTokenCount;
+        }
+    }

     /// <summary>
     /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
     /// </summary>
-    public int BatchQueueCount => _batchQueue.Count;
+    public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;

     /// <summary>
     /// Check if this executor has been disposed.
@@ -147,12 +164,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
         // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
         if (status != DecodeResult.Ok)
         {
-            _batchQueue.Insert(0, next);
+            RequeueFront(next);
             return status;
         }

         // Everything was ok, advance the epoch
         Epoch++;
+        CleanupQueue();

         return status;
     }
@@ -166,13 +184,44 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)

         IBatch? GetNextBatch()
         {
-            if (_batchQueue.Count == 0)
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
                 return null;
-
-            var nextBatch = _batchQueue[0];
-            _batchQueue.RemoveAt(0);
+            }
+
+            var nextBatch = _batchQueue[_batchQueueHead];
+            _batchQueueHead++;
+            _batchedTokenCountDirty = true;
             return nextBatch;
         }
+
+        void RequeueFront(IBatch batch)
+        {
+            Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
+            _batchQueue[--_batchQueueHead] = batch;
+            _batchedTokenCountDirty = true;
+        }
+
+        void CleanupQueue()
+        {
+            if (_batchQueueHead == 0)
+                return;
+
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
+                return;
+            }
+
+            if (_batchQueueHead > 16 && _batchQueueHead > _batchQueue.Count / 2)
+            {
+                _batchQueue.RemoveRange(0, _batchQueueHead);
+                _batchQueueHead = 0;
+            }
+        }
     }

     /// <inheritdoc />
@@ -202,7 +251,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

         // Find a batch with space for at least minCapacity tokens
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not TokenBatch { Batch: var batch })
@@ -213,13 +262,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;

             if (batch.TokenCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }

         // Add a new batch to the end of the queue
         var end = new LLamaBatch();
         _batchQueue.Add(new TokenBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }

     /// <summary>
@@ -234,7 +287,7 @@ internal LLamaSeqId GetNextSequenceId()
             throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

         // Find a batch with space for at least minCapacity embeddings
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
         {
             var item = _batchQueue[i];
             if (item is not EmbeddingBatch { Batch: var batch })
@@ -245,13 +298,17 @@ internal LLamaSeqId GetNextSequenceId()
                 continue;

             if (batch.EmbeddingsCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
         }

         // Add a new batch to the end of the queue
         var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
         _batchQueue.Add(new EmbeddingBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
     }

     #region batches
@@ -286,4 +343,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
     }
 }
 #endregion
-}
+}
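
Note: this is the core optimization of the commit. Dequeuing with List<IBatch>.RemoveAt(0) (and requeuing with Insert(0, ...)) shifts every remaining element, costing O(n) per batch; tracking a head index makes both operations O(1), and CleanupQueue only pays for RemoveRange when the dead prefix is both large (> 16) and a majority of the list, so compaction is amortized. The dirty flag gives the same treatment to BatchedTokenCount, which previously re-summed the whole queue via LINQ on every read. The epoch arithmetic also switches to positions relative to the head (i - _batchQueueHead + 1), so already-consumed slots no longer inflate the estimate of when a batch will run. A self-contained sketch of the pattern, using a hypothetical HeadQueue<T> (not an actual LLamaSharp type):

```csharp
using System;
using System.Collections.Generic;

// List<T> plus a head index: O(1) dequeue and requeue-front,
// with occasional compaction instead of per-item shifting.
public sealed class HeadQueue<T>
{
    private readonly List<T> _items = new();
    private int _head;

    public int Count => _items.Count - _head;

    public void Enqueue(T item) => _items.Add(item);

    public bool TryDequeue(out T? item)
    {
        if (_head >= _items.Count)
        {
            // Queue fully drained: safe to reset the backing list.
            _items.Clear();
            _head = 0;
            item = default;
            return false;
        }

        item = _items[_head++];
        return true;
    }

    // Put the most recently dequeued item back (mirrors RequeueFront in the diff).
    public void RequeueFront(T item)
    {
        if (_head == 0)
            throw new InvalidOperationException("Nothing has been dequeued.");
        _items[--_head] = item;
    }

    // Amortized cleanup: only remove the dead prefix when it is large
    // and dominates the list, matching the thresholds in CleanupQueue.
    public void Compact()
    {
        if (_head > 16 && _head > _items.Count / 2)
        {
            _items.RemoveRange(0, _head);
            _head = 0;
        }
    }
}
```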
