
Commit d5aab12

Move common logic to LlamaExecutorBase
1 parent 384ec34 commit d5aab12
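
In short: the MTMD (multimodal) bookkeeping that InstructExecutor previously carried as private members — chunk disposal, media-marker lookup, filler-token alignment, tokenization of the multimodal prompt, and chunk evaluation — becomes a set of protected helpers on the shared executor base class, leaving each executor's multimodal branch as a single EvaluateMtmdChunks call.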

3 files changed: +204, -359 lines

LLama/LLamaExecutorBase.cs

Lines changed: 195 additions & 0 deletions
@@ -81,6 +81,13 @@ public bool IsMultiModal
     /// <inheritdoc />
     public List<SafeMtmdEmbed> Embeds { get; }
 
+    /// <summary>
+    /// Pending multimodal chunks produced by the MTMD tokenizer.
+    /// </summary>
+    protected SafeMtmdInputChunks? MtmdChunks { get; set; }
+
+    private string? _mtmdMarker;
+
     private readonly StreamingTokenDecoder _decoder;
 
     /// <summary>
@@ -235,6 +242,194 @@ protected virtual void TryReuseMatchingPrefix()
         }
     }
 
+    /// <summary>
+    /// Dispose and clear any queued multimodal chunk collection.
+    /// </summary>
+    protected void DisposeMtmdChunks()
+    {
+        MtmdChunks?.Dispose();
+        MtmdChunks = null;
+    }
+
+    /// <summary>
+    /// Dispose and clear any pending multimodal embeddings.
+    /// </summary>
+    protected void DisposeEmbeds()
+    {
+        if (Embeds.Count == 0)
+            return;
+
+        foreach (var embed in Embeds)
+            embed.Dispose();
+
+        Embeds.Clear();
+    }
+
+    /// <summary>
+    /// Retrieve the marker token used to signal media segments to the tokenizer.
+    /// </summary>
+    protected string GetMtmdMarker()
+    {
+        if (_mtmdMarker is not null)
+            return _mtmdMarker;
+
+        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<media>";
+        return _mtmdMarker;
+    }
+
+    /// <summary>
+    /// Ensure the token list fills all positional slots reported by the MTMD helper.
+    /// </summary>
+    protected static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+    {
+        if (totalPositions <= tokens.Count)
+            return new List<LLamaToken>(tokens);
+
+        var result = new List<LLamaToken>(totalPositions);
+        result.AddRange(tokens);
+        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+        return result;
+    }
+
+    /// <summary>
+    /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions.
+    /// </summary>
+    protected LLamaToken GetFillerToken(string marker)
+    {
+        var markerTokens = Context.Tokenize(marker, false, true);
+        if (markerTokens.Length > 0)
+            return markerTokens[markerTokens.Length - 1];
+
+        var eos = Context.Vocab.EOS;
+        if (eos.HasValue)
+            return eos.Value;
+
+        return default;
+    }
+
+    /// <summary>
+    /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens.
+    /// </summary>
+    protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+        DisposeMtmdChunks();
+
+        var marker = GetMtmdMarker();
+        var prompt = text;
+
+        if (Embeds.Count > 0)
+        {
+            if (prompt.Contains("<image>"))
+                prompt = prompt.Replace("<image>", marker);
+
+            if (!prompt.Contains(marker))
+            {
+                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                prompt = string.Concat(prompt, suffix);
+            }
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+            }
+
+            MtmdChunks = chunks;
+
+            var tokens = new List<LLamaToken>();
+            foreach (var chunk in chunks.Enumerate())
+            {
+                using var scopedChunk = chunk;
+                if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+                    continue;
+
+                foreach (var token in scopedChunk.GetTextTokensSpan())
+                    tokens.Add(unchecked((int)token));
+            }
+
+            var totalPositions = (int)ClipModel.CountPositions(chunks);
+            var fillerToken = GetFillerToken(marker);
+
+            if (replaceExisting)
+            {
+                _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
+                _consumedTokensCount = 0;
+            }
+            else
+            {
+                if (_embed_inps.Count == 0)
+                    _embed_inps = new List<LLamaToken>();
+
+                _embed_inps.AddRange(tokens);
+                var fillerCount = totalPositions - tokens.Count;
+                if (fillerCount > 0)
+                    _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
+
+                args.RemainedTokens -= tokens.Count;
+            }
+        }
+        catch
+        {
+            chunks?.Dispose();
+            MtmdChunks = null;
+            throw;
+        }
+        finally
+        {
+            DisposeEmbeds();
+        }
+
+        return Task.CompletedTask;
+    }
+
+    /// <summary>
+    /// Apply bookkeeping after successfully evaluating multimodal chunks.
+    /// </summary>
+    protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed)
+    {
+        _pastTokensCount = checked((int)newNPast);
+        DisposeMtmdChunks();
+
+        if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
+        {
+            _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
+            _n_session_consumed = _session_tokens.Count;
+        }
+
+        _consumedTokensCount = _embed_inps.Count;
+        _embeds.Clear();
+    }
+
+    /// <summary>
+    /// Evaluate the queued MTMD chunks and update executor state.
+    /// </summary>
+    protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+        if (MtmdChunks is null)
+            throw new InvalidOperationException("No MTMD chunks are queued for evaluation.");
+
+        var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+            nBatch: checked((int)Context.BatchSize), logitsLast: true);
+        if (evalStatus != 0)
+        {
+            _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus);
+            DisposeMtmdChunks();
+            throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+        }
+
+        FinalizeMtmdEvaluation(nPast, previousConsumed);
+    }
+
     /// <summary>
     /// Determine whether the inference loop should continue processing tokens.
     /// </summary>
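
For orientation, here is a minimal sketch of how a derived executor might drive these new protected members. The executor below is hypothetical: the base-class name (StatefulExecutorBase, the type defined in LLamaExecutorBase.cs), constructor shape, and override signatures are assumptions, and its other required overrides are omitted; only PreprocessMtmd, MtmdChunks, EvaluateMtmdChunks, and the counters they touch come from the diff above.

using System.Threading.Tasks;
using LLama.Abstractions;

// Hypothetical derived executor; the remaining abstract members of the base
// class (state save/load, post-processing, loop condition) are omitted.
public class MediaExecutor : StatefulExecutorBase
{
    public MediaExecutor(LLamaContext context) : base(context) { }

    protected override Task PreprocessInputs(string? text, InferStateArgs args)
    {
        // Route prompts through the shared MTMD path whenever media embeds are
        // queued; replaceExisting: true rebuilds _embed_inps for a fresh prompt.
        if (IsMultiModal && text is not null)
            return PreprocessMtmd(text, args, addBos: true, replaceExisting: true);

        return Task.CompletedTask; // plain-text tokenization path elided
    }

    protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
    {
        if (IsMultiModal && MtmdChunks is not null)
        {
            var nPast = (long)_pastTokensCount;
            var previousConsumed = _consumedTokensCount;

            // Evaluates the queued chunks, advances _pastTokensCount, records
            // session tokens, and disposes the chunk set via FinalizeMtmdEvaluation.
            EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(MediaExecutor));
        }

        return Task.CompletedTask; // sampling/decoding loop elided
    }
}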

LLama/LLamaInstructExecutor.cs

Lines changed: 2 additions & 152 deletions
@@ -25,8 +25,6 @@ public class InstructExecutor
     private readonly string _instructionPrefix;
     private LLamaToken[] _inp_pfx;
     private LLamaToken[] _inp_sfx;
-    private SafeMtmdInputChunks? _mtmdChunks;
-    private string? _mtmdMarker;
     private readonly string _instructionSuffix;
 
     /// <summary>
@@ -190,136 +188,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         return Task.CompletedTask;
     }
 
-    private void DisposeMtmdChunks()
-    {
-        _mtmdChunks?.Dispose();
-        _mtmdChunks = null;
-    }
-
-    private void DisposeEmbeds()
-    {
-        if (Embeds.Count == 0)
-            return;
-
-        foreach (var embed in Embeds)
-            embed.Dispose();
-
-        Embeds.Clear();
-    }
-
-    private string GetMtmdMarker()
-    {
-        if (_mtmdMarker is not null)
-            return _mtmdMarker;
-
-        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<media>";
-        return _mtmdMarker;
-    }
-
-    private static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
-    {
-        if (totalPositions <= tokens.Count)
-            return new List<LLamaToken>(tokens);
-
-        var result = new List<LLamaToken>(totalPositions);
-        result.AddRange(tokens);
-        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
-        return result;
-    }
-
-    private LLamaToken GetFillerToken(string marker)
-    {
-        var markerTokens = Context.Tokenize(marker, false, true);
-        if (markerTokens.Length > 0)
-            return markerTokens[markerTokens.Length - 1];
-
-        var eos = Context.Vocab.EOS;
-        if (eos.HasValue)
-            return eos.Value;
-
-        return default(LLamaToken);
-    }
-
-    private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
-    {
-        if (ClipModel is null)
-            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
-
-        DisposeMtmdChunks();
-
-        var marker = GetMtmdMarker();
-        var prompt = text;
-
-        if (Embeds.Count > 0)
-        {
-            if (prompt.Contains("<image>"))
-                prompt = prompt.Replace("<image>", marker);
-
-            if (!prompt.Contains(marker))
-            {
-                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
-                prompt = string.Concat(prompt, suffix);
-            }
-        }
-
-        SafeMtmdInputChunks? chunks = null;
-        try
-        {
-            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
-            if (status != 0 || chunks is null)
-            {
-                ClipModel.ClearMedia();
-                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
-            }
-
-            _mtmdChunks = chunks;
-
-            var tokens = new List<LLamaToken>();
-            foreach (var chunk in chunks.Enumerate())
-            {
-                using var scopedChunk = chunk;
-                if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
-                    continue;
-
-                foreach (var token in scopedChunk.GetTextTokensSpan())
-                    tokens.Add(unchecked((int)token));
-            }
-
-            var totalPositions = (int)ClipModel.CountPositions(chunks);
-            var fillerToken = GetFillerToken(marker);
-
-            if (replaceExisting)
-            {
-                _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
-                _consumedTokensCount = 0;
-            }
-            else
-            {
-                if (_embed_inps.Count == 0)
-                    _embed_inps = new List<LLamaToken>();
-
-                _embed_inps.AddRange(tokens);
-                var fillerCount = totalPositions - tokens.Count;
-                if (fillerCount > 0)
-                    _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
-
-                args.RemainedTokens -= tokens.Count;
-            }
-        }
-        catch
-        {
-            chunks?.Dispose();
-            _mtmdChunks = null;
-            throw;
-        }
-        finally
-        {
-            DisposeEmbeds();
-        }
-
-        return Task.CompletedTask;
-    }
-
     /// <inheritdoc />
     protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
     {
@@ -380,30 +248,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 _n_session_consumed = _session_tokens.Count;
             }
         }
-        else if (IsMultiModal && _mtmdChunks is not null)
+        else if (IsMultiModal && MtmdChunks is not null)
        {
             _is_prompt_run = false;
             var nPast = (long)_pastTokensCount;
             var previousConsumed = _consumedTokensCount;
-            var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true);
-            if (evalStatus != 0)
-            {
-                _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus);
-                DisposeMtmdChunks();
-                throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
-            }
-
-            _pastTokensCount = checked((int)nPast);
-            DisposeMtmdChunks();
-
-            if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
-            {
-                _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
-                _n_session_consumed = _session_tokens.Count;
-            }
-
-            _consumedTokensCount = _embed_inps.Count;
-            _embeds.Clear();
+            EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor));
         }
 
         _embeds.Clear();

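Finally, a worked illustration of the filler alignment the shared code performs: a standalone restatement of BuildTokensWithFiller from the diff, with int standing in for LLamaToken and made-up token values. Media chunks occupy positions (counted by CountPositions) without contributing text tokens, so the executor pads _embed_inps up to the position count, keeping its token bookkeeping in step with the positions the evaluation consumes.

using System;
using System.Collections.Generic;
using System.Linq;

static class FillerDemo
{
    // Restatement of BuildTokensWithFiller from the diff above, with int in
    // place of LLamaToken so the example runs standalone.
    static List<int> BuildTokensWithFiller(List<int> tokens, int totalPositions, int fillerToken)
    {
        if (totalPositions <= tokens.Count)
            return new List<int>(tokens);

        var result = new List<int>(totalPositions);
        result.AddRange(tokens);
        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
        return result;
    }

    static void Main()
    {
        // Hypothetical values: 3 text tokens, 8 total positions reported by
        // the tokenizer (the rest belong to an image chunk), filler token 0.
        var padded = BuildTokensWithFiller(new List<int> { 101, 7592, 2088 }, 8, 0);
        Console.WriteLine(string.Join(", ", padded));
        // Output: 101, 7592, 2088, 0, 0, 0, 0, 0
    }
}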