
Commit ea4ba82

SignalRT authored and José Luis Santiago committed
Move common logic to LlamaExecutorBase
1 parent 03d4441 commit ea4ba82


3 files changed: 204 additions, 355 deletions


LLama/LLamaExecutorBase.cs

Lines changed: 195 additions & 0 deletions
@@ -86,6 +86,13 @@ public bool IsMultiModal
     /// <inheritdoc />
     public List<SafeMtmdEmbed> Embeds { get; }
 
+    /// <summary>
+    /// Pending multimodal chunks produced by the MTMD tokenizer.
+    /// </summary>
+    protected SafeMtmdInputChunks? MtmdChunks { get; set; }
+
+    private string? _mtmdMarker;
+
     private readonly StreamingTokenDecoder _decoder;
 
     /// <summary>
@@ -242,6 +249,194 @@ protected virtual void TryReuseMatchingPrefix()
         }
     }
 
+    /// <summary>
+    /// Dispose and clear any queued multimodal chunk collection.
+    /// </summary>
+    protected void DisposeMtmdChunks()
+    {
+        MtmdChunks?.Dispose();
+        MtmdChunks = null;
+    }
+
+    /// <summary>
+    /// Dispose and clear any pending multimodal embeddings.
+    /// </summary>
+    protected void DisposeEmbeds()
+    {
+        if (Embeds.Count == 0)
+            return;
+
+        foreach (var embed in Embeds)
+            embed.Dispose();
+
+        Embeds.Clear();
+    }
+
+    /// <summary>
+    /// Retrieve the marker token used to signal media segments to the tokenizer.
+    /// </summary>
+    protected string GetMtmdMarker()
+    {
+        if (_mtmdMarker is not null)
+            return _mtmdMarker;
+
+        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<media>";
+        return _mtmdMarker;
+    }
+
+    /// <summary>
+    /// Ensure the token list fills all positional slots reported by the MTMD helper.
+    /// </summary>
+    protected static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+    {
+        if (totalPositions <= tokens.Count)
+            return new List<LLamaToken>(tokens);
+
+        var result = new List<LLamaToken>(totalPositions);
+        result.AddRange(tokens);
+        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+        return result;
+    }
+
+    /// <summary>
+    /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions.
+    /// </summary>
+    protected LLamaToken GetFillerToken(string marker)
+    {
+        var markerTokens = Context.Tokenize(marker, false, true);
+        if (markerTokens.Length > 0)
+            return markerTokens[markerTokens.Length - 1];
+
+        var eos = Context.Vocab.EOS;
+        if (eos.HasValue)
+            return eos.Value;
+
+        return default;
+    }
+
+    /// <summary>
+    /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens.
+    /// </summary>
+    protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+        DisposeMtmdChunks();
+
+        var marker = GetMtmdMarker();
+        var prompt = text;
+
+        if (Embeds.Count > 0)
+        {
+            if (prompt.Contains("<image>"))
+                prompt = prompt.Replace("<image>", marker);
+
+            if (!prompt.Contains(marker))
+            {
+                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                prompt = string.Concat(prompt, suffix);
+            }
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+            }
+
+            MtmdChunks = chunks;
+
+            var tokens = new List<LLamaToken>();
+            foreach (var chunk in chunks.Enumerate())
+            {
+                using var scopedChunk = chunk;
+                if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+                    continue;
+
+                foreach (var token in scopedChunk.GetTextTokensSpan())
+                    tokens.Add(unchecked((int)token));
+            }
+
+            var totalPositions = (int)ClipModel.CountPositions(chunks);
+            var fillerToken = GetFillerToken(marker);
+
+            if (replaceExisting)
+            {
+                _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
+                _consumedTokensCount = 0;
+            }
+            else
+            {
+                if (_embed_inps.Count == 0)
+                    _embed_inps = new List<LLamaToken>();
+
+                _embed_inps.AddRange(tokens);
+                var fillerCount = totalPositions - tokens.Count;
+                if (fillerCount > 0)
+                    _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
+
+                args.RemainedTokens -= tokens.Count;
+            }
+        }
+        catch
+        {
+            chunks?.Dispose();
+            MtmdChunks = null;
+            throw;
+        }
+        finally
+        {
+            DisposeEmbeds();
+        }
+
+        return Task.CompletedTask;
+    }
+
+    /// <summary>
+    /// Apply bookkeeping after successfully evaluating multimodal chunks.
+    /// </summary>
+    protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed)
+    {
+        _pastTokensCount = checked((int)newNPast);
+        DisposeMtmdChunks();
+
+        if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
+        {
+            _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
+            _n_session_consumed = _session_tokens.Count;
+        }
+
+        _consumedTokensCount = _embed_inps.Count;
+        _embeds.Clear();
+    }
+
+    /// <summary>
+    /// Evaluate the queued MTMD chunks and update executor state.
+    /// </summary>
+    protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+        if (MtmdChunks is null)
+            throw new InvalidOperationException("No MTMD chunks are queued for evaluation.");
+
+        var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+            nBatch: checked((int)Context.BatchSize), logitsLast: true);
+        if (evalStatus != 0)
+        {
+            _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus);
+            DisposeMtmdChunks();
+            throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+        }
+
+        FinalizeMtmdEvaluation(nPast, previousConsumed);
+    }
+
     /// <summary>
     /// Determine whether the inference loop should continue processing tokens.
     /// </summary>
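
Editor's note: BuildTokensWithFiller exists because MTMD media chunks occupy sequence positions without emitting text tokens, so the executor pads _embed_inps until its length matches what CountPositions reports. The following is a minimal, self-contained sketch of that alignment; it substitutes int for LLamaToken and uses made-up token values so it runs without the library.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class FillerDemo
    {
        // Mirrors the padding rule of BuildTokensWithFiller above: extend the
        // token list until it covers every reported position.
        static List<int> BuildTokensWithFiller(List<int> tokens, int totalPositions, int fillerToken)
        {
            if (totalPositions <= tokens.Count)
                return new List<int>(tokens);

            var result = new List<int>(totalPositions);
            result.AddRange(tokens);
            result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
            return result;
        }

        static void Main()
        {
            // 4 text tokens, but the media chunk makes the prompt span 10 positions.
            var textTokens = new List<int> { 101, 102, 103, 104 };
            var padded = BuildTokensWithFiller(textTokens, totalPositions: 10, fillerToken: 0);

            Console.WriteLine(string.Join(", ", padded));
            // 101, 102, 103, 104, 0, 0, 0, 0, 0, 0
        }
    }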

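Editor's note: PreprocessMtmd also normalizes the prompt before tokenizing: when embeds are queued, any user-facing "<image>" placeholder is rewritten to the tokenizer's native marker, and if no marker appears at all, one is appended per queued embed. A sketch of just that string rewrite, assuming the "<media>" fallback marker used above:

    using System;
    using System.Linq;

    class MarkerDemo
    {
        // Mirrors the prompt normalization step of PreprocessMtmd above.
        static string NormalizePrompt(string prompt, string marker, int embedCount)
        {
            if (embedCount == 0)
                return prompt; // no media queued, leave the prompt untouched

            // Rewrite the user-facing placeholder to the tokenizer's marker.
            if (prompt.Contains("<image>"))
                prompt = prompt.Replace("<image>", marker);

            // Still no marker? Append one per queued media embed.
            if (!prompt.Contains(marker))
                prompt += string.Concat(Enumerable.Repeat(marker, embedCount));

            return prompt;
        }

        static void Main()
        {
            Console.WriteLine(NormalizePrompt("Describe <image> please.", "<media>", 1));
            // Describe <media> please.
            Console.WriteLine(NormalizePrompt("Describe this.", "<media>", 2));
            // Describe this.<media><media>
        }
    }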
LLama/LLamaInstructExecutor.cs

Lines changed: 2 additions & 152 deletions
@@ -26,8 +26,6 @@ public class InstructExecutor
     private readonly string _instructionPrefix;
     private LLamaToken[] _inp_pfx;
     private LLamaToken[] _inp_sfx;
-    private SafeMtmdInputChunks? _mtmdChunks;
-    private string? _mtmdMarker;
     private readonly string _instructionSuffix;
 
     /// <summary>
@@ -192,136 +190,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc
         return Task.CompletedTask;
     }
 
-    private void DisposeMtmdChunks()
-    {
-        _mtmdChunks?.Dispose();
-        _mtmdChunks = null;
-    }
-
-    private void DisposeEmbeds()
-    {
-        if (Embeds.Count == 0)
-            return;
-
-        foreach (var embed in Embeds)
-            embed.Dispose();
-
-        Embeds.Clear();
-    }
-
-    private string GetMtmdMarker()
-    {
-        if (_mtmdMarker is not null)
-            return _mtmdMarker;
-
-        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<media>";
-        return _mtmdMarker;
-    }
-
-    private static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
-    {
-        if (totalPositions <= tokens.Count)
-            return new List<LLamaToken>(tokens);
-
-        var result = new List<LLamaToken>(totalPositions);
-        result.AddRange(tokens);
-        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
-        return result;
-    }
-
-    private LLamaToken GetFillerToken(string marker)
-    {
-        var markerTokens = Context.Tokenize(marker, false, true);
-        if (markerTokens.Length > 0)
-            return markerTokens[markerTokens.Length - 1];
-
-        var eos = Context.Vocab.EOS;
-        if (eos.HasValue)
-            return eos.Value;
-
-        return default(LLamaToken);
-    }
-
-    private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
-    {
-        if (ClipModel is null)
-            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
-
-        DisposeMtmdChunks();
-
-        var marker = GetMtmdMarker();
-        var prompt = text;
-
-        if (Embeds.Count > 0)
-        {
-            if (prompt.Contains("<image>"))
-                prompt = prompt.Replace("<image>", marker);
-
-            if (!prompt.Contains(marker))
-            {
-                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
-                prompt = string.Concat(prompt, suffix);
-            }
-        }
-
-        SafeMtmdInputChunks? chunks = null;
-        try
-        {
-            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
-            if (status != 0 || chunks is null)
-            {
-                ClipModel.ClearMedia();
-                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
-            }
-
-            _mtmdChunks = chunks;
-
-            var tokens = new List<LLamaToken>();
-            foreach (var chunk in chunks.Enumerate())
-            {
-                using var scopedChunk = chunk;
-                if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
-                    continue;
-
-                foreach (var token in scopedChunk.GetTextTokensSpan())
-                    tokens.Add(unchecked((int)token));
-            }
-
-            var totalPositions = (int)ClipModel.CountPositions(chunks);
-            var fillerToken = GetFillerToken(marker);
-
-            if (replaceExisting)
-            {
-                _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
-                _consumedTokensCount = 0;
-            }
-            else
-            {
-                if (_embed_inps.Count == 0)
-                    _embed_inps = new List<LLamaToken>();
-
-                _embed_inps.AddRange(tokens);
-                var fillerCount = totalPositions - tokens.Count;
-                if (fillerCount > 0)
-                    _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
-
-                args.RemainedTokens -= tokens.Count;
-            }
-        }
-        catch
-        {
-            chunks?.Dispose();
-            _mtmdChunks = null;
-            throw;
-        }
-        finally
-        {
-            DisposeEmbeds();
-        }
-
-        return Task.CompletedTask;
-    }
-
     /// <inheritdoc />
     protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
     {
@@ -384,30 +252,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 _n_session_consumed = _session_tokens.Count;
             }
         }
-        else if (IsMultiModal && _mtmdChunks is not null)
+        else if (IsMultiModal && MtmdChunks is not null)
         {
             _is_prompt_run = false;
             var nPast = (long)_pastTokensCount;
             var previousConsumed = _consumedTokensCount;
-            var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true);
-            if (evalStatus != 0)
-            {
-                _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus);
-                DisposeMtmdChunks();
-                throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
-            }
-
-            _pastTokensCount = checked((int)nPast);
-            DisposeMtmdChunks();
-
-            if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
-            {
-                _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
-                _n_session_consumed = _session_tokens.Count;
-            }
-
-            _consumedTokensCount = _embed_inps.Count;
-            _embeds.Clear();
+            EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor));
         }
 
         _embeds.Clear();
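
Editor's note: taken together, the moved helpers define a two-phase lifecycle that any executor subclass can share: PreprocessMtmd queues chunks, EvaluateMtmdChunks consumes them. The sketch below shows that call order in a hypothetical subclass; MyExecutor is illustrative, the base-class name is assumed from the file name, and the overridden signatures are simplified relative to the real library code.

    // Hypothetical subclass, for illustration only.
    public class MyExecutor : LLamaExecutorBase // name assumed; actual base type may differ
    {
        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
        {
            if (IsMultiModal && text is not null)
            {
                // Phase 1: tokenize text + queued media, pad to the reported
                // position count, and stash the result in MtmdChunks.
                return PreprocessMtmd(text, args, addBos: true, replaceExisting: true);
            }

            // ... plain-text tokenization path ...
            return Task.CompletedTask;
        }

        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            if (IsMultiModal && MtmdChunks is not null)
            {
                var nPast = (long)_pastTokensCount;
                var previousConsumed = _consumedTokensCount;

                // Phase 2: evaluate the queued chunks. On success this advances
                // _pastTokensCount, extends the session token list, and disposes
                // the chunks; on failure it disposes them and throws RuntimeError.
                EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(MyExecutor));
            }

            // ... sampling / decoding continues as before ...
        }
    }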
