Skip to content

Commit 7a16be5

Browse files
committed
Extend Mtmd test
1 parent 3efe956 commit 7a16be5

File tree

1 file changed

+77
-36
lines changed

1 file changed

+77
-36
lines changed

LLama.Unittest/MtmdWeightsTests.cs

Lines changed: 77 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -44,56 +44,97 @@ public void Dispose()
4444
_llamaWeights.Dispose();
4545
}
4646

47-
[Fact,Trait("Category", "NoCI")]
48-
public void EmbedImageAsFileName()
47+
/// <summary>
/// Test helper: clears the media held by <c>_safeMtmdWeights</c>, obtains an embed from
/// <paramref name="loadEmbed"/>, asserts its basic properties and tokenizes the media marker prompt.
/// </summary>
/// <param name="loadEmbed">Factory invoked after <c>ClearMedia()</c> to produce the embed under test.</param>
/// <returns>The chunks produced by <c>Tokenize</c>; the test methods in this file dispose them with <c>using</c>.</returns>
private SafeMtmdInputChunks TokenizeWithEmbed(Func<SafeMtmdEmbed> loadEmbed)
4948
{
5049
_safeMtmdWeights.ClearMedia();
5150

52-
using var image = _safeMtmdWeights.LoadMedia(Constants.MtmdImage);
53-
Assert.NotNull(image);
54-
Assert.True(image.Nx > 0);
55-
56-
var prompt = _mediaMarker;
57-
SafeMtmdInputChunks? chunks;
58-
var status = _safeMtmdWeights.Tokenize(prompt, addSpecial: true, parseSpecial: true, out chunks);
59-
Assert.Equal(0, status);
60-
Assert.NotNull(chunks);
51+
var embed = loadEmbed();
52+
Assert.NotNull(embed);
6153

62-
var ownedChunks = chunks!;
63-
using (ownedChunks)
54+
// The embed is disposed when this using block exits, i.e. before the returned chunks reach the caller.
// NOTE(review): presumably Tokenize no longer needs the embed afterwards - confirm.
using (embed)
6455
{
65-
long nPast = 0;
66-
var eval = _safeMtmdWeights.EvaluateChunks(ownedChunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true);
67-
Assert.Equal(0, eval);
68-
Assert.True(nPast > 0);
56+
Assert.True(embed.Nx > 0);
57+
Assert.True(embed.Ny > 0);
58+
Assert.False(embed.IsAudio);
59+
Assert.True(embed.GetDataSpan().Length > 0);
60+
61+
// The prompt is only the media marker; a status of 0 is asserted (presumably the success code).
var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks);
62+
Assert.Equal(0, status);
63+
Assert.NotNull(chunks);
64+
65+
return chunks!;
6966
}
70-
}
71-
67+
}
68+
69+
/// <summary>
/// Test helper: evaluates <paramref name="chunks"/> against <c>_context</c> through
/// <c>_safeMtmdWeights.EvaluateChunks</c> and asserts that the call returns 0 and advances <c>nPast</c>.
/// </summary>
/// <param name="chunks">Tokenized chunks to evaluate; this helper does not dispose them.</param>
private void AssertChunksEvaluate(SafeMtmdInputChunks chunks)
70+
{
71+
// Evaluation starts from position 0 on sequence 0.
long nPast = 0;
72+
// checked(...) makes the BatchSize-to-int conversion throw on overflow instead of truncating.
var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true);
73+
Assert.Equal(0, eval);
74+
// nPast is passed by ref; a value above 0 shows the evaluation moved the position forward.
Assert.True(nPast > 0);
75+
}
76+
77+
/// <summary>
/// Loads the test image by file path (<c>Constants.MtmdImage</c>), tokenizes it via
/// <c>TokenizeWithEmbed</c> and checks that the resulting chunks evaluate via <c>AssertChunksEvaluate</c>.
/// </summary>
[Fact,Trait("Category", "NoCI")]
78+
public void EmbedImageAsFileName()
79+
{
80+
// The using declaration disposes the returned chunks when the test method exits.
using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage));
81+
AssertChunksEvaluate(chunks);
82+
}
83+
7284
/// <summary>
/// Loads the test image from an in-memory byte array read from <c>Constants.MtmdImage</c>, tokenizes it via
/// <c>TokenizeWithEmbed</c> and checks that the resulting chunks evaluate via <c>AssertChunksEvaluate</c>.
/// </summary>
[Fact,Trait("Category", "NoCI")]
7385
public void EmbedImageAsBinary()
7486
{
75-
_safeMtmdWeights.ClearMedia();
87+
// The file is read before the helper runs; the lambda only hands the bytes to LoadMedia.
var imageBytes = File.ReadAllBytes(Constants.MtmdImage);
88+
using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes));
89+
AssertChunksEvaluate(chunks);
90+
}
91+
92+
/// <summary>
/// Tokenizes the test image and checks the metadata exposed by the resulting chunks: per-chunk token and
/// position counts must add up to the totals reported by <c>_safeMtmdWeights</c>, at least one image chunk
/// must be present, and image chunks must be copyable with matching counts. It also asserts the
/// vision/audio capability flags and the audio bitrate of the loaded weights.
/// </summary>
[Fact,Trait("Category", "NoCI")]
93+
public void TokenizeProvidesChunkMetadata()
94+
{
95+
using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage));
7696

77-
byte[] imageBytes = File.ReadAllBytes(Constants.MtmdImage);
78-
using var image = _safeMtmdWeights.LoadMedia(imageBytes);
79-
Assert.NotNull(image);
80-
Assert.True(image.Nx > 0);
97+
Assert.True(chunks.Size > 0);
8198

82-
var prompt = _mediaMarker;
83-
SafeMtmdInputChunks? chunks;
84-
var status = _safeMtmdWeights.Tokenize(prompt, addSpecial: true, parseSpecial: true, out chunks);
85-
Assert.Equal(0, status);
86-
Assert.NotNull(chunks);
99+
// Running sums over all chunks, compared against the aggregate counts after the loop.
ulong totalTokens = 0;
100+
long totalPositions = 0;
101+
var imageChunks = 0;
87102

88-
var ownedChunks = chunks!;
89-
using (ownedChunks)
103+
foreach (var chunk in chunks.Enumerate())
90104
{
91-
long nPast = 0;
92-
var eval = _safeMtmdWeights.EvaluateChunks(ownedChunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true);
93-
Assert.Equal(0, eval);
94-
Assert.True(nPast > 0);
105+
totalTokens += chunk.NTokens;
106+
totalPositions += chunk.NPos;
107+
108+
if (chunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image)
109+
{
110+
imageChunks++;
111+
112+
// Copy() is treated as possibly returning null (see the null checks below); try/finally
// disposes the copy even when one of the assertions throws.
var copy = chunk.Copy();
113+
try
114+
{
115+
Assert.NotNull(copy);
116+
if (copy != null)
117+
{
118+
Assert.Equal(chunk.NTokens, copy.NTokens);
119+
Assert.Equal(chunk.NPos, copy.NPos);
120+
}
121+
}
122+
finally
123+
{
124+
copy?.Dispose();
125+
}
126+
}
95127
}
96-
}
97128

129+
Assert.True(imageChunks > 0);
130+
Assert.True(totalTokens > 0);
131+
// The per-chunk sums must equal the totals the weights report for the whole chunk list.
Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks));
132+
Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks));
133+
Assert.True(_safeMtmdWeights.SupportsVision);
134+
Assert.False(_safeMtmdWeights.SupportsAudio);
135+
136+
// NOTE(review): presumably a non-positive bitrate means "no audio support" for this model - confirm.
var audioBitrate = _safeMtmdWeights.AudioBitrate;
137+
Assert.True(audioBitrate <= 0);
138+
}
98139
}
99140
}

0 commit comments

Comments
 (0)