
Commit 68ae07c

Review Executors
Interactive and Instruct executors seem to work. BatchedExecutor is not working at all.
1 parent af399a7

17 files changed: +817 −144 lines

LLama.Examples/ExampleRunner.cs

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ public class ExampleRunner
     { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
     { "Batched Executor: Fork", BatchedExecutorFork.Run },
     { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
-    // { "Batched Executor: LLava", BatchedExecutorLLava.Run },
+    { "Batched Executor: Mtmd", BatchedExecutorMtmd.Run },
     { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
     { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
     { "Custom Sampling Pipeline", CustomSampler.Run },

LLama.Examples/Examples/BatchedExecutorLLava.cs

Lines changed: 0 additions & 92 deletions
This file was deleted.
LLama.Examples/Examples/BatchedExecutorMtmd.cs

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
using System;
using System.Collections.Generic;
using System.IO;
using LLama.Batched;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// Demonstrates how to evaluate an image with MTMD helpers and continue generation by
/// manually scheduling batches, similar to what the batched executor does internally.
/// </summary>
public class BatchedExecutorMtmd
{
    /// <summary>
    /// Number of completion tokens to generate after sending the image prompt.
    /// </summary>
    public const int TokenCount = 64;

    public static async Task Run()
    {
        var parameters = new ModelParams(UserSettings.GetModelPath());
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        var mtmdParams = MtmdContextParams.Default();
        mtmdParams.UseGpu = false;
        var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "<media>";

        using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams);

        using var executor = new BatchedExecutor(model, parameters, mtmd);

        var defaultPrompt = "\nUSER: Provide a full description of the image.\nASSISTANT: ";
        var promptSuffix = AnsiConsole.Ask("Prompt (or ENTER for default):", defaultPrompt);
        var promptText = string.Concat(marker, promptSuffix);

        var imagePath = UserSettings.GetImagePath();
        AnsiConsole.Write(new CanvasImage(imagePath));

        var vocab = executor.Context.NativeHandle.ModelHandle.Vocab;

        var sampler = new DefaultSamplingPipeline
        {
            Temperature = 0.1f
        };

        var decoder = new StreamingTokenDecoder(executor.Context)
        {
            DecodeSpecialTokens = false
        };

        try
        {
            var conversation = executor.Create();
            conversation.QueueMedia(imagePath);
            conversation.Prompt(promptText, addBos: true, special: true);

            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("Prompt queued with multimodal chunks. Generating response...\n");
            Console.ResetColor();

            var remaining = TokenCount;
            while (remaining > 0)
            {
                var decodeResult = await executor.Infer();
                if (decodeResult == DecodeResult.NoKvSlot)
                {
                    Console.ForegroundColor = ConsoleColor.Red;
                    Console.WriteLine("Insufficient KV cache space for multimodal evaluation.");
                    Console.ResetColor();
                    break;
                }

                if (decodeResult != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to evaluate batch: {decodeResult}.");

                if (!conversation.RequiresSampling)
                    continue;

                var token = conversation.Sample(sampler);
                if (token.IsEndOfGeneration(vocab))
                    break;

                decoder.Add(token);
                var delta = decoder.Read();
                if (!string.IsNullOrEmpty(delta))
                    Console.Write(delta);

                sampler.Accept(token);
                conversation.Prompt(token);
                remaining--;
            }

            Console.WriteLine();
        }
        catch (IOException ex)
        {
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine($"Could not load media '{imagePath}': {ex.Message}");
            Console.ResetColor();
        }
        catch (RuntimeError ex)
        {
            Console.ForegroundColor = ConsoleColor.Red;
            Console.WriteLine($"MTMD processing failed: {ex.Message}");
            Console.ResetColor();
        }
    }
}
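
The loop above follows the batched-executor contract: Infer() advances every queued batch, RequiresSampling gates conversations that produced logits, and each sampled token is fed back with Prompt(token). The same pattern extends to several conversations sharing one executor; a minimal sketch using only the calls shown in this file (prompts are illustrative, error handling omitted):

    // Sketch: two conversations on one BatchedExecutor, reusing the
    // Create/Prompt/Infer/Sample calls from the example above.
    var first = executor.Create();
    var second = executor.Create();
    first.Prompt("\nUSER: Describe the image.\nASSISTANT: ", addBos: true, special: true);
    second.Prompt("\nUSER: List the objects you see.\nASSISTANT: ", addBos: true, special: true);

    await executor.Infer(); // one call evaluates all queued batches (large prompts may need more)
    foreach (var conversation in new[] { first, second })
        if (conversation.RequiresSampling)
            conversation.Prompt(conversation.Sample(sampler));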

LLama.Examples/Examples/MtmdInteractiveModeExecute.cs

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ public static async Task Run()
     var parameters = new ModelParams(modelPath);
 
     var mtmdParameters = MtmdContextParams.Default();
+    mtmdParameters.UseGpu = false;
 
     using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
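
The example and the tests in this commit configure MTMD the same way: start from MtmdContextParams.Default(), override individual fields, then load the projector weights. A minimal sketch of that shared pattern, assuming a placeholder mmproj path (only fields exercised elsewhere in this diff are shown):

    // Sketch of the shared MTMD setup pattern; the path is a placeholder.
    var mtmdParams = MtmdContextParams.Default();
    mtmdParams.UseGpu = false;   // evaluate the projector on CPU
    mtmdParams.NThreads = 4;     // worker threads for media encoding
    var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "<media>";

    using var mtmd = await SafeMtmdWeights.LoadFromFileAsync("path/to/mmproj.gguf", model, mtmdParams);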

LLama.Examples/LLama.Examples.csproj

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
   <!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
   <IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
   <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
-  <LangVersion>12</LangVersion>
+  <LangVersion>13</LangVersion>
   <NoWarn>1701;1702;8604;SKEXP0001;SKEXP0050;SKEXP0052;SKEXP0003</NoWarn>
 </PropertyGroup>
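
Bumping LangVersion opts the examples project into C# 13. One feature this makes available is params collections, where params parameters can bind to span and collection types instead of only arrays; an illustrative snippet, not used anywhere in this commit:

    // C# 13 params collections: params now accepts ReadOnlySpan<T> and
    // other collection types. Illustrative only.
    static int SumTokens(params ReadOnlySpan<int> tokenIds)
    {
        var total = 0;
        foreach (var id in tokenIds)
            total += id;
        return total;
    }

    var total = SumTokens(1, 2, 3); // bound to a stack-allocated span, no array allocation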

LLama.Unittest/MtmdExecutorTests.cs

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Native;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;

namespace LLama.Unittest;

[Trait("Category", "NoCI")]
public class MtmdExecutorTests : IDisposable
{
    private readonly LLamaWeights _weights;
    private readonly MtmdContextParams _mtmdParams;
    private readonly SafeMtmdWeights _mtmd;
    private readonly ModelParams _modelParams;

    public MtmdExecutorTests()
    {
        _modelParams = new ModelParams(Constants.MtmdModelPath)
        {
            ContextSize = 1024 * 8,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        _weights = LLamaWeights.LoadFromFile(_modelParams);

        _mtmdParams = MtmdContextParams.Default();
        _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount);
        _mtmdParams.UseGpu = false;

        _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams);
    }

    public void Dispose()
    {
        _mtmd.Dispose();
        _weights.Dispose();
    }

    [Fact]
    public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize()
    {
        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
        var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance);
        var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "<media>";
        var prompt = $"{marker}\nDescribe the image succinctly.";

        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));

        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
        {
            Assert.True(false, "Prefill should not emit generated text");
        }

        var diagnostics = executor.GetDiagnostics();
        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
        Assert.Equal(0, diagnostics.PendingEmbedCount);
    }

    [Fact]
    public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce()
    {
        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
        var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance);
        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));

        var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "<media>"} Provide details.";

        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
        {
        }

        var diagnostics = executor.GetDiagnostics();
        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
        Assert.Equal(0, diagnostics.PendingEmbedCount);
    }
}
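
The three assertions in each test encode a single invariant: every queued multimodal embed is consumed exactly once and advances the past-token count exactly once, leaving nothing pending. If more executor tests follow, a shared helper keeps them consistent; a hypothetical sketch, assuming the diagnostics shape used above (dynamic stands in for the concrete type returned by GetDiagnostics()):

    // Hypothetical helper; `dynamic` stands in for the concrete
    // diagnostics type returned by GetDiagnostics().
    private static void AssertPrefillConsumedOnce(dynamic diagnostics)
    {
        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount); // every embed consumed
        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);  // n_past advanced once per token
        Assert.Equal(0, diagnostics.PendingEmbedCount);                  // queue drained
    }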
