88using System . Text . Json ;
99using System . Text . Json . Serialization ;
1010using System . Threading . Tasks ;
11+ using LLama ;
1112using LLama . Exceptions ;
1213using LLama . Sampling ;
1314using Microsoft . Extensions . Logging ;
@@ -21,12 +22,10 @@ namespace LLama
2122 public class InteractiveExecutor : StatefulExecutorBase
2223 {
2324 private bool _is_prompt_run = true ;
24-
25- // LLava
26- private int _EmbedImagePosition = - 1 ;
27- // TODO JLS:
28- //private List<SafeMtmdImageEmbedHandle> _imageEmbedHandles = new List<SafeMtmdImageEmbedHandle>();
29- private bool _imageInPrompt = false ;
25+
26+ // MTMD multimodal state
27+ private SafeMtmdInputChunks ? _mtmdChunks ;
28+ private string ? _mtmdMarker ;
3029
3130 /// <summary>
3231 ///
@@ -71,6 +70,7 @@ public override ExecutorBaseState GetStateData()
7170 /// <inheritdoc />
7271 public override Task LoadState ( ExecutorBaseState data )
7372 {
73+ DisposeMtmdChunks ( ) ;
7474 if ( data is InteractiveExecutorState state )
7575 {
7676 _n_session_consumed = state . ConsumedSessionCount ;
@@ -130,7 +130,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
130130 }
131131 else
132132 {
133- PreprocessLlava ( text , args , true ) ;
133+ PreprocessMtmd ( text , args , true ) ;
134134 }
135135 }
136136 else
@@ -151,51 +151,121 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
151151 }
152152 else
153153 {
154- PreprocessLlava ( text , args , false ) ;
154+ PreprocessMtmd ( text , args , false ) ;
155155 }
156156 }
157157 }
158158
159159 return Task . CompletedTask ;
160160 }
161161
/// <summary>
/// Disposes the MTMD input chunks currently held in <c>_mtmdChunks</c> (if any)
/// and resets the field to <c>null</c>.
/// </summary>
private void DisposeMtmdChunks()
{
    if (_mtmdChunks is not null)
    {
        _mtmdChunks.Dispose();
    }

    _mtmdChunks = null;
}
167+
/// <summary>
/// Disposes every entry in <c>Embeds</c> and then empties the collection.
/// Does nothing when the collection is already empty.
/// </summary>
private void DisposeEmbeds()
{
    if (Embeds.Count > 0)
    {
        foreach (var pendingEmbed in Embeds)
        {
            pendingEmbed.Dispose();
        }

        Embeds.Clear();
    }
}
182+
/// <summary>
/// Returns the media marker string, caching it in <c>_mtmdMarker</c> on first use.
/// The value is taken from <c>NativeApi.MtmdDefaultMarker()</c>; when that yields
/// <c>null</c> the literal "&lt;media&gt;" is used instead.
/// </summary>
/// <returns>The cached marker string.</returns>
private string GetMtmdMarker()
{
    return _mtmdMarker ??= NativeApi.MtmdDefaultMarker() ?? "<media>";
}
193+
/// <summary>
/// Tokenizes a prompt through the MTMD clip model, stores the resulting chunks in
/// <c>_mtmdChunks</c>, and queues the text tokens found in those chunks.
/// </summary>
/// <param name="text">
/// Prompt text. When embeds are pending, every "&lt;image&gt;" tag is rewritten to the
/// MTMD marker; if the prompt then contains no marker, one marker per embed is appended.
/// </param>
/// <param name="args">
/// Inference state; <c>RemainedTokens</c> is reduced by the number of queued text tokens
/// when <paramref name="addBos"/> is <c>false</c>.
/// </param>
/// <param name="addBos">
/// Forwarded to the tokenizer. When <c>true</c> the token list replaces <c>_embed_inps</c>;
/// otherwise the tokens are appended to it.
/// </param>
/// <returns>A completed task.</returns>
/// <exception cref="InvalidOperationException">No clip model is loaded.</exception>
/// <exception cref="RuntimeError">Tokenization returned a non-zero status or produced no chunks.</exception>
private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
{
    if (ClipModel is null)
    {
        throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
    }

    // Release chunks left over from an earlier call before producing new ones.
    DisposeMtmdChunks();

    var mediaMarker = GetMtmdMarker();
    var promptText = text;

    if (Embeds.Count > 0)
    {
        // Replace leaves the string unchanged when the tag is absent.
        promptText = promptText.Replace("<image>", mediaMarker);

        if (!promptText.Contains(mediaMarker))
        {
            // No marker anywhere in the prompt: append one per pending embed.
            promptText += string.Concat(Enumerable.Repeat(mediaMarker, Embeds.Count));
        }
    }

    SafeMtmdInputChunks? tokenized = null;
    try
    {
        var status = ClipModel.Tokenize(promptText, addBos, parseSpecial: true, out tokenized);
        if (status != 0 || tokenized is null)
        {
            ClipModel.ClearMedia();
            throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
        }

        _mtmdChunks = tokenized;

        // Only text chunks contribute to the token list; other chunk types are skipped here.
        var textTokens = new List<LLamaToken>();
        foreach (var chunk in tokenized.Enumerate())
        {
            using var current = chunk;
            if (current.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
            {
                foreach (var raw in current.GetTextTokensSpan())
                {
                    textTokens.Add(unchecked((int)raw));
                }
            }
        }

        if (!addBos)
        {
            _embed_inps.AddRange(textTokens);
            args.RemainedTokens -= textTokens.Count;
        }
        else
        {
            _embed_inps = textTokens;
        }
    }
    catch
    {
        // On any failure, do not keep a reference to chunks that are being disposed.
        tokenized?.Dispose();
        _mtmdChunks = null;
        throw;
    }
    finally
    {
        // Pending embeds are disposed and cleared whether or not tokenization succeeded.
        DisposeEmbeds();
    }

    return Task.CompletedTask;
}
201271
@@ -255,49 +325,60 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
255325 HandleRunOutOfContext ( tokensToKeep ) ;
256326 }
257327
258- TryReuseMatchingPrefix ( ) ;
328+ if ( _mtmdChunks is null )
329+ {
330+ TryReuseMatchingPrefix ( ) ;
331+ }
259332
260- // Changes to support Multi-Modal LLMs.
261- //
262- ( DecodeResult , int , int ) header , end , result ;
263- if ( IsMultiModal && _EmbedImagePosition > 0 )
333+ if ( IsMultiModal && _mtmdChunks is not null )
264334 {
265- // Tokens previous to the images
266- header = await Context . DecodeAsync ( _embeds . GetRange ( 0 , _EmbedImagePosition ) , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
267- _pastTokensCount = header . Item3 ;
268-
269- if ( header . Item1 != DecodeResult . Ok ) throw new LLamaDecodeError ( header . Item1 ) ;
270-
271- // TODO JLS:
272- // Images
273- //foreach( var image in _imageEmbedHandles )
274- // ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
275-
276- // Post-image Tokens
277- end = await Context . DecodeAsync ( _embeds . GetRange ( _EmbedImagePosition , _embeds . Count - _EmbedImagePosition ) , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
278- _pastTokensCount = end . Item3 ;
279-
280- _EmbedImagePosition = - 1 ;
281- // TODO JLS:
282- //_imageEmbedHandles.Clear();
283- Embeds . Clear ( ) ;
335+ var nPast = ( long ) _pastTokensCount ;
336+ var evalStatus = ClipModel ! . EvaluateChunks ( _mtmdChunks , Context . NativeHandle , ref nPast , seqId : 0 ,
337+ nBatch : checked ( ( int ) Context . BatchSize ) , logitsLast : true ) ;
338+ if ( evalStatus != 0 )
339+ {
340+ DisposeMtmdChunks ( ) ;
341+ throw new RuntimeError ( $ "Failed to evaluate multimodal chunks. Status: { evalStatus } .") ;
342+ }
343+
344+ _pastTokensCount = checked ( ( int ) nPast ) ;
345+ DisposeMtmdChunks ( ) ;
346+
347+ if ( _embeds . Count > 0 && ! string . IsNullOrEmpty ( _pathSession ) )
348+ {
349+ _session_tokens . AddRange ( _embeds ) ;
350+ _n_session_consumed = _session_tokens . Count ;
351+ }
284352 }
285353 else
286354 {
287- result = await Context . DecodeAsync ( _embeds , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
355+ var result = await Context . DecodeAsync ( _embeds , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
288356 _pastTokensCount = result . Item3 ;
289357
290358 if ( result . Item1 != DecodeResult . Ok ) throw new LLamaDecodeError ( result . Item1 ) ;
291- }
292-
293359
294- if ( _embeds . Count > 0 && ! string . IsNullOrEmpty ( _pathSession ) )
295- {
296- _session_tokens . AddRange ( _embeds ) ;
297- _n_session_consumed = _session_tokens . Count ;
360+ if ( _embeds . Count > 0 && ! string . IsNullOrEmpty ( _pathSession ) )
361+ {
362+ _session_tokens . AddRange ( _embeds ) ;
363+ _n_session_consumed = _session_tokens . Count ;
364+ }
298365 }
299366 }
367+ else if ( IsMultiModal && _mtmdChunks is not null )
368+ {
369+ _is_prompt_run = false ;
370+ var nPast = ( long ) _pastTokensCount ;
371+ var evalStatus = ClipModel ! . EvaluateChunks ( _mtmdChunks , Context . NativeHandle , ref nPast , seqId : 0 , nBatch : checked ( ( int ) Context . BatchSize ) , logitsLast : true ) ;
372+ if ( evalStatus != 0 )
373+ {
374+ DisposeMtmdChunks ( ) ;
375+ throw new RuntimeError ( $ "Failed to evaluate multimodal chunks. Status: { evalStatus } .") ;
376+ }
300377
378+ _pastTokensCount = checked ( ( int ) nPast ) ;
379+ DisposeMtmdChunks ( ) ;
380+ }
381+
301382 _embeds . Clear ( ) ;
302383
303384 if ( _embed_inps . Count <= _consumedTokensCount && ! args . WaitForInput )
@@ -351,7 +432,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
351432 /// The descriptor of the state of the interactive executor.
352433 /// </summary>
353434 public class InteractiveExecutorState
354- : ExecutorBaseState
435+ : StatefulExecutorBase . ExecutorBaseState
355436 {
356437 /// <summary>
357438 /// Whether the executor is running for the first time (running the prompt).
0 commit comments