Skip to content

Commit 9d78f71

Browse files
committed
Review logic
1 parent 22b4c75 commit 9d78f71

File tree

3 files changed

+23
-33
lines changed

3 files changed

+23
-33
lines changed

LLama.Examples/Examples/BatchedExecutorMtmd.cs

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -63,26 +63,27 @@ public static async Task Run()
6363
Console.ResetColor();
6464

6565
var remaining = TokenCount;
66-
while (remaining > 0)
66+
67+
async Task<bool> ProcessNextAsync()
6768
{
6869
var decodeResult = await executor.Infer();
6970
if (decodeResult == DecodeResult.NoKvSlot)
7071
{
7172
Console.ForegroundColor = ConsoleColor.Red;
7273
Console.WriteLine("Insufficient KV cache space for multimodal evaluation.");
7374
Console.ResetColor();
74-
break;
75+
return false;
7576
}
7677

7778
if (decodeResult != DecodeResult.Ok)
7879
throw new RuntimeError($"Failed to evaluate batch: {decodeResult}.");
7980

8081
if (!conversation.RequiresSampling)
81-
continue;
82+
return true;
8283

8384
var token = conversation.Sample(sampler);
8485
if (token.IsEndOfGeneration(vocab))
85-
break;
86+
return false;
8687

8788
decoder.Add(token);
8889
var delta = decoder.Read();
@@ -92,6 +93,11 @@ public static async Task Run()
9293
sampler.Accept(token);
9394
conversation.Prompt(token);
9495
remaining--;
96+
return remaining > 0;
97+
}
98+
99+
while (remaining > 0 && await ProcessNextAsync())
100+
{
95101
}
96102

97103
Console.WriteLine();

LLama/Batched/BatchedExecutor.cs

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -332,17 +332,12 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
332332
try
333333
{
334334
var nPast = _conversation.GetMtmdPast();
335-
for (var i = 0; i < _sequence.Chunks.Count; i++)
335+
var status = _clipModel.EvaluateChunks(_sequence.Chunks, ctx.NativeHandle, ref nPast,
336+
(int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast: true);
337+
if (status != 0)
336338
{
337-
var chunk = _sequence.Chunks[i];
338-
var logitsLast = i == _sequence.Chunks.Count - 1;
339-
var status = _clipModel.EvaluateChunk(chunk.NativePtr, ctx.NativeHandle, ref nPast,
340-
(int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast);
341-
if (status != 0)
342-
{
343-
_conversation.OnMtmdEvaluationFailed(status);
344-
return Task.FromResult(DecodeResult.DecodeFailed);
345-
}
339+
_conversation.OnMtmdEvaluationFailed(status);
340+
return Task.FromResult(DecodeResult.DecodeFailed);
346341
}
347342

348343
_conversation.OnMtmdEvaluationCompleted(nPast, _sequence);

LLama/Batched/Conversation.cs

Lines changed: 8 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -75,12 +75,12 @@ internal Conversation(BatchedExecutor batch, LLamaSeqId id)
7575

7676
internal sealed class MtmdChunkSequence : IDisposable
7777
{
78-
public List<SafeMtmdInputChunk> Chunks { get; }
78+
public SafeMtmdInputChunks Chunks { get; }
7979
public List<LLamaToken> TextTokens { get; }
8080
public int TotalPositions { get; }
8181
public int TotalTokens => TextTokens.Count;
8282

83-
private MtmdChunkSequence(List<SafeMtmdInputChunk> chunks, List<LLamaToken> textTokens, int totalPositions)
83+
private MtmdChunkSequence(SafeMtmdInputChunks chunks, List<LLamaToken> textTokens, int totalPositions)
8484
{
8585
Chunks = chunks;
8686
TextTokens = textTokens;
@@ -89,38 +89,27 @@ private MtmdChunkSequence(List<SafeMtmdInputChunk> chunks, List<LLamaToken> text
8989

9090
public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel)
9191
{
92-
var copies = new List<SafeMtmdInputChunk>();
9392
var textTokens = new List<LLamaToken>();
9493

9594
foreach (var chunk in chunks.Enumerate())
9695
{
97-
var copy = chunk.Copy();
98-
if (copy == null)
96+
using (chunk)
9997
{
100-
chunk.Dispose();
101-
continue;
102-
}
103-
104-
copies.Add(copy);
98+
if (chunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
99+
continue;
105100

106-
if (copy.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
107-
{
108-
foreach (var token in copy.GetTextTokensSpan())
101+
foreach (var token in chunk.GetTextTokensSpan())
109102
textTokens.Add((LLamaToken)unchecked((int)token));
110103
}
111-
112-
chunk.Dispose();
113104
}
114105

115106
var totalPositions = (int)clipModel.CountPositions(chunks);
116-
return new MtmdChunkSequence(copies, textTokens, totalPositions);
107+
return new MtmdChunkSequence(chunks, textTokens, totalPositions);
117108
}
118109

119110
public void Dispose()
120111
{
121-
foreach (var chunk in Chunks)
122-
chunk.Dispose();
123-
Chunks.Clear();
112+
Chunks.Dispose();
124113
}
125114
}
126115

0 commit comments

Comments
 (0)