Commit 3efe956

WIP
1 parent 78f6137 commit 3efe956

8 files changed (+425, −159 lines)


LLama/LLamaExecutorBase.cs

Lines changed: 50 additions & 47 deletions
@@ -32,11 +32,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         /// </summary>
         protected int _consumedTokensCount; // n_consume
         /// <summary>
-        ///
+        /// Number of tokens consumed from the session cache during the current run.
         /// </summary>
         protected int _n_session_consumed;
         /// <summary>
-        ///
+        /// Number of prompt tokens that match the loaded session cache prefix.
         /// </summary>
         protected int _n_matching_session_tokens;
         /// <summary>
@@ -52,7 +52,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         /// </summary>
         protected List<LLamaToken> _embed_inps = new();
         /// <summary>
-        ///
+        /// Tokens recovered from the session file and reused to warm up the KV cache.
         /// </summary>
         protected List<LLamaToken> _session_tokens = new();
         /// <summary>
@@ -84,10 +84,10 @@ public bool IsMultiModal
         private readonly StreamingTokenDecoder _decoder;

         /// <summary>
-        ///
+        /// Initialize a stateful executor bound to a specific context.
         /// </summary>
-        /// <param name="context"></param>
-        /// <param name="logger"></param>
+        /// <param name="context">LLama context used for all native interactions.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
         protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
         {
             Embeds = new List<SafeMtmdEmbed>();
@@ -101,22 +101,22 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
         }

         /// <summary>
-        ///
+        /// Initialize a multimodal executor with the supplied MTMD weights.
         /// </summary>
-        /// <param name="context"></param>
-        /// <param name="safeMtmdWeights"></param>
-        /// <param name="logger"></param>
+        /// <param name="context">LLama context used for all native interactions.</param>
+        /// <param name="safeMtmdWeights">Multimodal weights to associate with this executor.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
         public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) :
             this(context, logger)
         {
             ClipModel = safeMtmdWeights;
         }

         /// <summary>
-        /// This API is currently not verified.
+        /// Attach a session cache file so the executor can reuse previous KV state if compatible.
         /// </summary>
-        /// <param name="filename"></param>
-        /// <returns></returns>
+        /// <param name="filename">Path to the llama.cpp session file.</param>
+        /// <returns>The current executor instance for fluent configuration.</returns>
         /// <exception cref="ArgumentNullException"></exception>
         /// <exception cref="RuntimeError"></exception>
         public StatefulExecutorBase WithSessionFile(string filename)
@@ -173,9 +173,9 @@ public StatefulExecutorBase WithSessionFile(string filename)
         }

         /// <summary>
-        /// This API has not been verified currently.
+        /// Persist the current session cache to disk.
         /// </summary>
-        /// <param name="filename"></param>
+        /// <param name="filename">Destination path for the llama.cpp session file.</param>
         public void SaveSessionFile(string filename)
         {
             var session_token_array = _session_tokens.ToArray();
@@ -203,7 +203,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep)
         }

         /// <summary>
-        /// Try to reuse the matching prefix from the session file.
+        /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens.
         /// </summary>
         protected virtual void TryReuseMatchingPrefix()
         {
@@ -236,66 +236,66 @@ protected virtual void TryReuseMatchingPrefix()
         }

         /// <summary>
-        /// Decide whether to continue the loop.
+        /// Determine whether the inference loop should continue processing tokens.
         /// </summary>
-        /// <param name="args"></param>
-        /// <returns></returns>
+        /// <param name="args">Mutable state associated with the current inference.</param>
+        /// <returns><c>true</c> to continue generating; otherwise <c>false</c>.</returns>
         protected abstract Task<bool> GetLoopCondition(InferStateArgs args);

         /// <summary>
-        /// Preprocess the inputs before the inference.
+        /// Prepare the executor for inference by tokenizing input and updating cached state.
         /// </summary>
-        /// <param name="text"></param>
-        /// <param name="args"></param>
+        /// <param name="text">Prompt text to process.</param>
+        /// <param name="args">Mutable state associated with the current inference.</param>
         protected abstract Task PreprocessInputs(string? text, InferStateArgs args);

         /// <summary>
-        /// Do some post processing after the inference.
+        /// Perform any post-processing on the generated tokens.
         /// </summary>
-        /// <param name="inferenceParams"></param>
-        /// <param name="args"></param>
-        /// <returns></returns>
+        /// <param name="inferenceParams">Parameters controlling sampling.</param>
+        /// <param name="args">Mutable state associated with the current inference.</param>
+        /// <returns>A tuple indicating whether generation should stop and any extra outputs to emit.</returns>
         protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);

         /// <summary>
-        /// The core inference logic.
+        /// Core inference loop that advances the model by one step.
         /// </summary>
-        /// <param name="inferenceParams"></param>
-        /// <param name="args"></param>
+        /// <param name="inferenceParams">Parameters controlling sampling.</param>
+        /// <param name="args">Mutable state associated with the current inference.</param>
         protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);

         /// <summary>
-        /// Save the current state to a file.
+        /// Save the executor state to a serialized snapshot file.
         /// </summary>
-        /// <param name="filename"></param>
+        /// <param name="filename">Destination file for the serialized state.</param>
         public abstract Task SaveState(string filename);

         /// <summary>
-        /// Get the current state data.
+        /// Capture the executor state in a serializable object.
         /// </summary>
-        /// <returns></returns>
+        /// <returns>State snapshot suitable for persistence.</returns>
         public abstract ExecutorBaseState GetStateData();

         /// <summary>
-        /// Load the state from data.
+        /// Restore executor state from a previously captured snapshot.
         /// </summary>
-        /// <param name="data"></param>
+        /// <param name="data">State snapshot created by <see cref="GetStateData"/>.</param>
         public abstract Task LoadState(ExecutorBaseState data);

         /// <summary>
-        /// Load the state from a file.
+        /// Restore executor state from a serialized snapshot file.
         /// </summary>
-        /// <param name="filename"></param>
+        /// <param name="filename">Path to the snapshot produced by <see cref="SaveState"/>.</param>
        public abstract Task LoadState(string filename);


         /// <summary>
-        /// Execute the inference.
+        /// Execute an asynchronous inference session.
         /// </summary>
-        /// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
+        /// <param name="text">Optional prompt; when null generation resumes from prior state.</param>
+        /// <param name="inferenceParams">Sampling parameters to apply; defaults are used when null.</param>
+        /// <param name="cancellationToken">Cancellation token for cooperative cancellation.</param>
+        /// <returns>Stream of decoded text segments as they become available.</returns>
         public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
@@ -370,33 +370,36 @@ public virtual async Task PrefillPromptAsync(string prompt)
         }

         /// <summary>
-        /// State arguments that are used in single inference
+        /// Mutable state passed between inference callbacks during a single generation pass.
         /// </summary>
         protected class InferStateArgs
         {
             /// <summary>
-            ///
+            /// Anti-prompts that terminate generation when encountered.
             /// </summary>
             public IList<string>? Antiprompts { get; set; }
             /// <summary>
             /// Tokens count remained to be used. (n_remain)
             /// </summary>
             public int RemainedTokens { get; set; }
             /// <summary>
-            ///
+            /// Indicates whether generated tokens should be returned to the caller.
             /// </summary>
             public bool ReturnValue { get; set; }
             /// <summary>
-            ///
+            /// Signals that the executor should pause and wait for additional user input.
             /// </summary>
             public bool WaitForInput { get; set; }
             /// <summary>
-            ///
+            /// Indicates whether the session cache should be persisted after inference completes.
             /// </summary>
             public bool NeedToSaveSession { get; set; }
         }

 #pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings
+        /// <summary>
+        /// Serializable snapshot of executor state used for persistence and restart.
+        /// </summary>
         [JsonConverter(typeof(PolymorphicJSONConverter<ExecutorBaseState>))]
         public class ExecutorBaseState
         {
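
Usage sketch for the session-cache workflow documented above, assuming LLamaSharp's usual setup types (ModelParams, LLamaWeights, InferenceParams) from elsewhere in the repo; the model path and session file name are illustrative.

using System;
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf") { ContextSize = 4096 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// WithSessionFile returns the executor itself, so it can be chained at construction.
var executor = new InteractiveExecutor(context)
    .WithSessionFile("chat.session");

// InferAsync streams decoded text segments as they become available.
await foreach (var segment in executor.InferAsync("Hello!", new InferenceParams()))
    Console.Write(segment);

// Persist the KV state so a later run can reuse the matching prompt prefix.
executor.SaveSessionFile("chat.session");

Per the comments above, WithSessionFile loads any compatible KV state before the first prompt is evaluated, and SaveSessionFile writes it back so TryReuseMatchingPrefix can skip re-evaluating the shared prefix on the next run.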

LLama/LLamaInteractExecutor.cs

Lines changed: 41 additions & 21 deletions
@@ -21,28 +21,29 @@ namespace LLama
     /// </summary>
     public class InteractiveExecutor : StatefulExecutorBase
     {
+        // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn.
         private bool _is_prompt_run = true;

         // MTMD multimodal state
-        private SafeMtmdInputChunks? _mtmdChunks;
-        private string? _mtmdMarker;
+        private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer.
+        private string? _mtmdMarker; // Cached multimodal marker returned by the native helper.

         /// <summary>
-        ///
+        /// Create an interactive executor for text-only inference.
         /// </summary>
-        /// <param name="context"></param>
-        /// <param name="logger"></param>
+        /// <param name="context">LLama context to operate against.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
         public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
             : base(context, logger)
         {
         }

         /// <summary>
-        ///
+        /// Create an interactive multimodal executor that can process text alongside media inputs.
         /// </summary>
-        /// <param name="context"></param>
-        /// <param name="clipModel"></param>
-        /// <param name="logger"></param>
+        /// <param name="context">LLama context to operate against.</param>
+        /// <param name="clipModel">Multimodal weights (MTMD) to attach to the executor.</param>
+        /// <param name="logger">Optional logger for diagnostic output.</param>
         public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null)
             : base(context, clipModel, logger)
         {
@@ -109,15 +110,20 @@ public override async Task LoadState(string filename)
         }

         /// <summary>
-        /// Define whether to continue the loop to generate responses.
+        /// Decide whether generation should continue for the current iteration.
         /// </summary>
-        /// <returns></returns>
+        /// <param name="args">Mutable inference state.</param>
+        /// <returns><c>true</c> to keep generating; otherwise <c>false</c>.</returns>
         protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
             return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }

-        /// <inheritdoc />
+        /// <summary>
+        /// Preprocess the incoming prompt or continuation text before inference.
+        /// </summary>
+        /// <param name="text">Prompt text or continuation provided by the caller.</param>
+        /// <param name="args">Mutable inference state.</param>
         protected override Task PreprocessInputs(string? text, InferStateArgs args)
         {
             if (_is_prompt_run)
@@ -159,12 +165,18 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
             return Task.CompletedTask;
         }

+        /// <summary>
+        /// Release any queued multimodal chunks and reset state.
+        /// </summary>
         private void DisposeMtmdChunks()
         {
             _mtmdChunks?.Dispose();
             _mtmdChunks = null;
         }

+        /// <summary>
+        /// Dispose and clear any pending multimodal embeddings queued for evaluation.
+        /// </summary>
         private void DisposeEmbeds()
         {
             if (Embeds.Count == 0)
@@ -180,6 +192,9 @@ private void DisposeEmbeds()
             Embeds.Clear();
         }

+        /// <summary>
+        /// Retrieve the marker token used to signal media segments to the tokenizer.
+        /// </summary>
         private string GetMtmdMarker()
         {
             if (_mtmdMarker is not null)
@@ -191,7 +206,12 @@ private string GetMtmdMarker()
             return _mtmdMarker;
         }

-        /// <inheritdoc />
+        /// <summary>
+        /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers.
+        /// </summary>
+        /// <param name="text">Prompt text containing optional media markers.</param>
+        /// <param name="args">Mutable inference state.</param>
+        /// <param name="addBos">Whether to treat the prompt as a fresh run and add the BOS token.</param>
         private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
         {
             if (ClipModel is null)
@@ -213,7 +233,7 @@ private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)

                 if (!prompt.Contains(marker))
                 {
-                    var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                    var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed.
                     prompt = string.Concat(prompt, suffix);
                 }
             }
@@ -228,7 +248,7 @@ private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
                 throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
             }

-            _mtmdChunks = chunks;
+            _mtmdChunks = chunks; // Own the chunk collection until evaluation completes.

             var tokens = new List<LLamaToken>();
             foreach (var chunk in chunks.Enumerate())
@@ -263,18 +283,18 @@ private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true)
             }
             finally
             {
-                DisposeEmbeds();
+                DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval.
             }

             return Task.CompletedTask;
         }

         /// <summary>
-        /// Return whether to break the generation.
+        /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers.
         /// </summary>
-        /// <param name="inferenceParams"></param>
-        /// <param name="args"></param>
-        /// <returns></returns>
+        /// <param name="inferenceParams">Sampling parameters controlling generation.</param>
+        /// <param name="args">Mutable inference state.</param>
+        /// <returns>Tuple describing whether to stop and any additional outputs to emit.</returns>
         protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
@@ -429,7 +449,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         }

         /// <summary>
-        /// The descriptor of the state of the interactive executor.
+        /// Serializable state specific to the interactive executor.
         /// </summary>
         public class InteractiveExecutorState
             : StatefulExecutorBase.ExecutorBaseState
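
Sketch of the marker-alignment rule PreprocessMtmd applies above, written as a hypothetical standalone helper (AlignMarkers and embedCount are illustrative names; in the executor the marker comes from GetMtmdMarker() and the count from Embeds.Count).

using System.Linq;

static string AlignMarkers(string prompt, string marker, int embedCount)
{
    // The multimodal tokenizer expects exactly one marker per pending embed;
    // if the caller supplied none, append them so media chunks line up with the text.
    if (!prompt.Contains(marker))
        prompt += string.Concat(Enumerable.Repeat(marker, embedCount));
    return prompt;
}

Markers are only appended when the prompt contains none at all, so callers that position markers explicitly keep full control over where media chunks land in the token stream.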

0 commit comments
