Commit 4e4eaf9

WIP

1 parent 71a44c9, commit 4e4eaf9

2 files changed: +162 -64 lines changed


LLama/LLamaInteractExecutor.cs

Lines changed: 144 additions & 63 deletions
@@ -8,6 +8,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Threading.Tasks;
+using LLama;
 using LLama.Exceptions;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -21,12 +22,10 @@ namespace LLama
     public class InteractiveExecutor : StatefulExecutorBase
     {
         private bool _is_prompt_run = true;
-
-        // LLava
-        private int _EmbedImagePosition = -1;
-        // TODO JLS:
-        //private List<SafeMtmdImageEmbedHandle> _imageEmbedHandles = new List<SafeMtmdImageEmbedHandle>();
-        private bool _imageInPrompt = false;
+
+        // MTMD multimodal state
+        private SafeMtmdInputChunks? _mtmdChunks;
+        private string? _mtmdMarker;
 
         /// <summary>
         ///
@@ -71,6 +70,7 @@ public override ExecutorBaseState GetStateData()
         /// <inheritdoc />
         public override Task LoadState(ExecutorBaseState data)
         {
+            DisposeMtmdChunks();
             if (data is InteractiveExecutorState state)
             {
                 _n_session_consumed = state.ConsumedSessionCount;
@@ -130,7 +130,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                 }
                 else
                 {
-                    PreprocessLlava(text, args, true);
+                    PreprocessMtmd(text, args, true);
                 }
             }
             else
@@ -151,51 +151,121 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                     }
                     else
                     {
-                        PreprocessLlava(text, args, false);
+                        PreprocessMtmd(text, args, false);
                     }
                 }
             }
 
             return Task.CompletedTask;
         }
 
+        private void DisposeMtmdChunks()
+        {
+            _mtmdChunks?.Dispose();
+            _mtmdChunks = null;
+        }
+
+        private void DisposeEmbeds()
+        {
+            if (Embeds.Count == 0)
+            {
+                return;
+            }
+
+            foreach (var embed in Embeds)
+            {
+                embed.Dispose();
+            }
+
+            Embeds.Clear();
+        }
+
+        private string GetMtmdMarker()
+        {
+            if (_mtmdMarker is not null)
+            {
+                return _mtmdMarker;
+            }
+
+            _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<media>";
+            return _mtmdMarker;
+        }
+
         /// <inheritdoc />
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
-        {
-            // If the prompt contains the tag <image> extract this.
-            _imageInPrompt = text.Contains("<image>");
-            if (_imageInPrompt && IsMultiModal)
+        private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
+        {
+            if (ClipModel is null)
+            {
+                throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+            }
+
+            DisposeMtmdChunks();
+
+            var marker = GetMtmdMarker();
+            var prompt = text;
+
+            if (Embeds.Count > 0)
             {
-                foreach (var embed in Embeds)
+                if (prompt.Contains("<image>"))
                 {
-                    // TODO JLS:
-                    //_imageEmbedHandles.Add(SafeMtmdImageEmbedHandle.CreateFromMemory(ClipModel!.NativeHandle, Context, image));
+                    prompt = prompt.Replace("<image>", marker);
                 }
 
-                int imageIndex = text.IndexOf("<image>");
-                // Tokenize segment 1 (before <image> tag)
-                string preImagePrompt = text.Substring(0, imageIndex);
-                var segment1 = Context.Tokenize(preImagePrompt, addBos, true);
-                // Remember the position to add the image embeddings
-                _EmbedImagePosition = segment1.Length;
-                string postImagePrompt = text.Substring(imageIndex + 7);
-                var segment2 = Context.Tokenize(postImagePrompt, false, true);
-                _embed_inps.AddRange(segment1);
-                _embed_inps.AddRange(segment2);
+                if (!prompt.Contains(marker))
+                {
+                    var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                    prompt = string.Concat(prompt, suffix);
+                }
             }
-            else
+
+            SafeMtmdInputChunks? chunks = null;
+            try
             {
+                var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+                if (status != 0 || chunks is null)
+                {
+                    ClipModel.ClearMedia();
+                    throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+                }
+
+                _mtmdChunks = chunks;
+
+                var tokens = new List<LLamaToken>();
+                foreach (var chunk in chunks.Enumerate())
+                {
+                    using var scopedChunk = chunk;
+                    if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+                    {
+                        continue;
+                    }
+
+                    foreach (var token in scopedChunk.GetTextTokensSpan())
+                    {
+                        tokens.Add(unchecked((int)token));
+                    }
+                }
+
                 if (addBos)
                 {
-                    _embed_inps = Context.Tokenize(text, true, true).ToList();
+                    _embed_inps = tokens;
                 }
                 else
                 {
-                    var line_inp = Context.Tokenize(text, false, true);
-                    _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;
+                    _embed_inps.AddRange(tokens);
+                    args.RemainedTokens -= tokens.Count;
                }
             }
+            catch
+            {
+                chunks?.Dispose();
+                _mtmdChunks = null;
+                throw;
+            }
+            finally
+            {
+                DisposeEmbeds();
+            }
+
             return Task.CompletedTask;
         }
 
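Aside: the prompt-rewriting rule in PreprocessMtmd above is easy to state on its own. Any literal "<image>" tag is replaced with the native MTMD marker, and if the prompt ends up containing no marker at all, one marker per pending media embed is appended at the end; when no media is queued, the prompt is left alone. A minimal standalone sketch of just that rule (ExpandMediaMarkers is a hypothetical helper written for illustration, not part of this commit):

using System;
using System.Linq;

static class MarkerDemo
{
    // Hypothetical helper mirroring the rewrite in PreprocessMtmd:
    // with pending media, "<image>" tags become the native marker, and if
    // no marker remains, one marker per pending embed is appended.
    static string ExpandMediaMarkers(string prompt, string marker, int embedCount)
    {
        if (embedCount == 0)
            return prompt; // no pending media: prompt is left untouched

        prompt = prompt.Replace("<image>", marker);
        if (!prompt.Contains(marker))
            prompt += string.Concat(Enumerable.Repeat(marker, embedCount));
        return prompt;
    }

    static void Main()
    {
        // Tag rewritten in place:
        Console.WriteLine(ExpandMediaMarkers("Describe <image>", "<media>", 1));
        // No tag at all: two pending embeds append two markers:
        Console.WriteLine(ExpandMediaMarkers("Compare these photos", "<media>", 2));
    }
}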
@@ -255,49 +325,60 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
                     HandleRunOutOfContext(tokensToKeep);
                 }
 
-                TryReuseMatchingPrefix();
+                if (_mtmdChunks is null)
+                {
+                    TryReuseMatchingPrefix();
+                }
 
-                // Changes to support Multi-Modal LLMs.
-                //
-                (DecodeResult, int, int) header, end, result;
-                if (IsMultiModal && _EmbedImagePosition > 0)
+                if (IsMultiModal && _mtmdChunks is not null)
                 {
-                    // Tokens previous to the images
-                    header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
-                    _pastTokensCount = header.Item3;
-
-                    if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-
-                    // TODO JLS:
-                    // Images
-                    //foreach( var image in _imageEmbedHandles )
-                    //    ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-
-                    // Post-image Tokens
-                    end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
-                    _pastTokensCount = end.Item3;
-
-                    _EmbedImagePosition = -1;
-                    // TODO JLS:
-                    //_imageEmbedHandles.Clear();
-                    Embeds.Clear();
+                    var nPast = (long)_pastTokensCount;
+                    var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+                        nBatch: checked((int)Context.BatchSize), logitsLast: true);
+                    if (evalStatus != 0)
+                    {
+                        DisposeMtmdChunks();
+                        throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+                    }
+
+                    _pastTokensCount = checked((int)nPast);
+                    DisposeMtmdChunks();
+
+                    if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
+                    {
+                        _session_tokens.AddRange(_embeds);
+                        _n_session_consumed = _session_tokens.Count;
+                    }
                 }
                 else
                 {
-                    result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+                    var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
                     _pastTokensCount = result.Item3;
 
                     if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
-                }
-
 
-                if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
-                {
-                    _session_tokens.AddRange(_embeds);
-                    _n_session_consumed = _session_tokens.Count;
+                    if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
+                    {
+                        _session_tokens.AddRange(_embeds);
+                        _n_session_consumed = _session_tokens.Count;
+                    }
                 }
             }
+            else if (IsMultiModal && _mtmdChunks is not null)
+            {
+                _is_prompt_run = false;
+                var nPast = (long)_pastTokensCount;
+                var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true);
+                if (evalStatus != 0)
+                {
+                    DisposeMtmdChunks();
+                    throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+                }
 
+                _pastTokensCount = checked((int)nPast);
+                DisposeMtmdChunks();
+            }
+
             _embeds.Clear();
 
             if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
@@ -351,7 +432,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
     /// The descriptor of the state of the interactive executor.
     /// </summary>
     public class InteractiveExecutorState
-        : ExecutorBaseState
+        : StatefulExecutorBase.ExecutorBaseState
     {
         /// <summary>
         /// Whether the executor is running for the first time (running the prompt).
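For orientation, a rough usage sketch of the reworked executor. Model and context setup use the existing LLamaSharp API; how the mtmd clip model and media embeds get attached is not shown in this diff, so that step appears only as a comment, and "model.gguf" is a placeholder path.

using System;
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf");            // placeholder path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// With an mtmd clip model loaded and media queued on executor.Embeds,
// a "<image>" tag in the prompt is rewritten to the native marker and
// the prompt is evaluated via ClipModel.EvaluateChunks instead of a
// plain DecodeAsync.
await foreach (var piece in executor.InferAsync(
                   "What is in <image>?",
                   new InferenceParams { MaxTokens = 128 }))
{
    Console.Write(piece);
}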

LLama/Native/SafeMtmdInputChunks.cs

Lines changed: 18 additions & 1 deletion
@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -34,4 +35,20 @@ public IntPtr GetChunkPtr(ulong index)
         if (index >= Size) throw new IndexOutOfRangeException();
         return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index);
     }
-}
+
+    /// <summary>
+    /// Enumerate the contained chunks as non-owning wrappers.
+    /// Callers should dispose the returned chunk if they create a copy.
+    /// </summary>
+    public IEnumerable<SafeMtmdInputChunk> Enumerate()
+    {
+        for (ulong i = 0; i < Size; i++)
+        {
+            var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i));
+            if (chunk != null)
+            {
+                yield return chunk;
+            }
+        }
+    }
+}
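A small sketch of the new enumerator in use, assuming chunks is a SafeMtmdInputChunks produced by a successful ClipModel.Tokenize call. Each yielded wrapper is non-owning, so it is disposed with a using declaration, as the doc comment advises:

// Count text vs. media chunks in a tokenized multimodal prompt.
int textChunks = 0, mediaChunks = 0;
foreach (var chunk in chunks.Enumerate())
{
    using var c = chunk;   // dispose the non-owning wrapper
    if (c.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
        textChunks++;
    else
        mediaChunks++;
}
Console.WriteLine($"text chunks: {textChunks}, media chunks: {mediaChunks}");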
