Skip to content

Commit d2fd7a8

Browse files
committed
Improve Florence2 segmentation, add missing merges.txt
1 parent b8f6c1a commit d2fd7a8

File tree

8 files changed

+140
-22
lines changed

8 files changed

+140
-22
lines changed

TensorStack.TextGeneration/Pipelines/DecoderPipeline.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
33

4-
using Google.Protobuf.WellKnownTypes;
54
using System;
65
using System.Collections.Generic;
76
using System.Linq;
@@ -245,7 +244,7 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options
245244
/// </summary>
246245
/// <param name="candidates">The sequences.</param>
247246
/// <param name="options">The options.</param>
248-
protected IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection candidates, GenerateOptions options)
247+
protected virtual IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection candidates, GenerateOptions options)
249248
{
250249
// TODO: Diversity Penalty
251250
_sequenceComparer.SetLength(options.DiversityLength);

TensorStack.TextGeneration/Pipelines/Florence/FlorenceOptions.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,14 @@ public record FlorenceOptions : GenerateOptions
1313
public CoordinateBox<float> Region { get; set; }
1414
}
1515

16+
public record FlorenceSearchOptions : FlorenceOptions;
17+
18+
19+
20+
1621
public enum TaskType
1722
{
23+
NONE,
1824
OCR,
1925
OCR_WITH_REGION,
2026
CAPTION,

TensorStack.TextGeneration/Pipelines/Florence/FlorencePipeline.cs

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// Licensed under the Apache 2.0 License.
33
using Microsoft.ML.OnnxRuntime;
44
using System;
5+
using System.Collections.Generic;
56
using System.IO;
67
using System.Linq;
8+
using System.Runtime.CompilerServices;
79
using System.Threading;
810
using System.Threading.Tasks;
911
using TensorStack.Common;
@@ -15,7 +17,9 @@
1517

1618
namespace TensorStack.TextGeneration.Pipelines.Florence
1719
{
18-
public class FlorencePipeline : EncoderDecoderPipeline
20+
public class FlorencePipeline : EncoderDecoderPipeline,
21+
IPipeline<GenerateResult, FlorenceOptions>,
22+
IPipelineStream<GenerateResult, FlorenceSearchOptions>
1923
{
2024
private readonly FlorenceConfig _configuration;
2125
private readonly PreProcessor _preProcessor;
@@ -103,6 +107,35 @@ public virtual async Task<GenerateResult> RunAsync(FlorenceOptions options, IPro
103107
}
104108

105109

110+
public virtual async IAsyncEnumerable<GenerateResult> RunAsync(FlorenceSearchOptions options, IProgress<RunProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
111+
{
112+
var textPrompt = _preProcessor.ProcessPrompt(options);
113+
var imagePrompt = _preProcessor.ProcessImage(options);
114+
115+
TokenizerOutput = await Tokenizer.EncodeAsync(textPrompt);
116+
var embedsOutput = await RunTextEmbedAsync(TokenizerOutput.InputIds);
117+
_visionOutput = await RunVisionEncoderAsync(embedsOutput, imagePrompt);
118+
EncoderOutput = await RunEncoderAsync();
119+
120+
var sequences = await BeamSearchAsync(options, cancellationToken);
121+
foreach (var sequence in sequences)
122+
{
123+
using (sequence)
124+
{
125+
var processedBeamOutput = _postProcessor.Process(options, sequence.Tokens);
126+
yield return new GenerateResult
127+
{
128+
Beam = sequence.Id,
129+
Score = sequence.Score,
130+
PenaltyScore = sequence.PenaltyScore,
131+
Result = processedBeamOutput.Result,
132+
CoordinateResults = processedBeamOutput.CoordinateResults
133+
};
134+
}
135+
}
136+
}
137+
138+
106139
/// <summary>
107140
/// Run encoder model.
108141
/// </summary>
@@ -233,14 +266,20 @@ protected override void Dispose(bool disposing)
233266
/// <param name="embedModel">The embed model.</param>
234267
/// <param name="visionModel">The vision model.</param>
235268
/// <returns>FlorencePipeline.</returns>
236-
public static FlorencePipeline Create(ExecutionProvider provider, string modelPath, string encoderModel = "encoder_model.onnx", string decoderModel = "decoder_model_merged.onnx", string embedModel = "embed_tokens.onnx", string visionModel = "vision_encoder.onnx")
269+
public static FlorencePipeline Create(ExecutionProvider provider, string modelPath, FlorenceType modelType, string encoderModel = "encoder_model.onnx", string decoderModel = "decoder_model_merged.onnx", string embedModel = "embed_tokens.onnx", string visionModel = "vision_encoder.onnx")
237270
{
238-
// Florence Large
239-
//NumLayers = 12,
240-
//NumHeads = 16,
241-
//NumKVHeads = 16,
242-
//HiddenSize = 768,
243-
//VocabSize = 51289
271+
var numLayers = 6;
272+
var numHeads = 12;
273+
var numKVHeads = 12;
274+
var hiddenSize = 768;
275+
var vocabSize = 51289;
276+
if (modelType == FlorenceType.Large)
277+
{
278+
numLayers = 12;
279+
numHeads = 16;
280+
numKVHeads = 16;
281+
hiddenSize = 1024;
282+
}
244283

245284
var config = new FlorenceConfig
246285
{
@@ -251,20 +290,20 @@ public static FlorencePipeline Create(ExecutionProvider provider, string modelPa
251290
}),
252291
EncoderConfig = new EncoderConfig
253292
{
254-
NumLayers = 6,
255-
NumHeads = 12,
256-
NumKVHeads = 12,
257-
HiddenSize = 768,
258-
VocabSize = 51289,
293+
NumLayers = numLayers,
294+
NumHeads = numHeads,
295+
NumKVHeads = numKVHeads,
296+
HiddenSize = hiddenSize,
297+
VocabSize = vocabSize,
259298
Path = Path.Combine(modelPath, encoderModel),
260299
},
261300
DecoderConfig = new DecoderConfig
262301
{
263-
NumLayers = 6,
264-
NumHeads = 12,
265-
NumKVHeads = 12,
266-
HiddenSize = 768,
267-
VocabSize = 51289,
302+
NumLayers = numLayers,
303+
NumHeads = numHeads,
304+
NumKVHeads = numKVHeads,
305+
HiddenSize = hiddenSize,
306+
VocabSize = vocabSize,
268307
Path = Path.Combine(modelPath, decoderModel),
269308
},
270309
EmbedsConfig = new ModelConfig
@@ -278,6 +317,7 @@ public static FlorencePipeline Create(ExecutionProvider provider, string modelPa
278317
};
279318

280319
config.EncoderConfig.SetProvider(provider);
320+
//config.DecoderConfig.SetProvider(provider);
281321
config.DecoderConfig.SetProvider(ProviderCPU()); // TODO
282322
config.EmbedsConfig.SetProvider(provider);
283323
config.VisionConfig.SetProvider(provider);
@@ -302,4 +342,10 @@ private static ExecutionProvider ProviderCPU()
302342
});
303343
}
304344
}
345+
346+
public enum FlorenceType
347+
{
348+
Base = 0,
349+
Large = 1
350+
}
305351
}

TensorStack.TextGeneration/Pipelines/Florence/FlorenceTokenizer.cs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class FlorenceTokenizer : ITokenizer
2121
private readonly MapCollection<byte, char> _unicodeMap;
2222
private readonly MapCollection<long, int> _coordinateMap;
2323
private readonly MapCollection<long, string> _vocabularyMap;
24+
private readonly Dictionary<MergeToken, int> _mergesMap;
2425
private readonly MapCollection<long, string> _specialTokensMap;
2526

2627
/// <summary>
@@ -34,6 +35,7 @@ public FlorenceTokenizer(TokenizerConfig configuration)
3435
_specialTokensMap = CreateSpecialTokenMapping();
3536
_vocabularyMap = CreateVocabMapping();
3637
_coordinateMap = CreateCoordinateMapping();
38+
_mergesMap = CreateMergesMapping();
3739
_preTokenizeRegex = new Regex(@"'s|'t|'re|'ve|'m|'ll|'d|<loc_[\p{L}\p{N}_]+>| ?[\p{L}_][\p{L}\p{N}_]*|[^ \s\p{L}\p{N}]+|\s+(?!\S)|\s+", RegexOptions.Compiled);
3840
}
3941

@@ -214,6 +216,7 @@ private string TokensToString(IEnumerable<string> tokens)
214216
private long[] StringToTokens(ReadOnlySpan<char> input)
215217
{
216218
var tokens = PreTokenize(input)
219+
.SelectMany(ApplyMerges)
217220
.Select(TokenToId)
218221
.Prepend(_configuration.BOS)
219222
.Append(_configuration.EOS)
@@ -241,6 +244,42 @@ private string[] PreTokenize(ReadOnlySpan<char> input)
241244
}
242245

243246

247+
/// <summary>
248+
/// Applies the merges.
249+
/// </summary>
250+
/// <param name="token">The token.</param>
251+
private List<string> ApplyMerges(string token)
252+
{
253+
if (_specialTokensMap.ContainsKey(token))
254+
return [token];
255+
256+
var symbols = token.Select(c => c.ToString()).ToList();
257+
while (symbols.Count > 1)
258+
{
259+
int bestIndex = -1;
260+
int bestRank = int.MaxValue;
261+
for (int i = 0; i < symbols.Count - 1; i++)
262+
{
263+
var pair = new MergeToken(symbols[i], symbols[i + 1]);
264+
if (_mergesMap.TryGetValue(pair, out int rank) && rank < bestRank)
265+
{
266+
bestRank = rank;
267+
bestIndex = i;
268+
}
269+
}
270+
271+
if (bestIndex == -1)
272+
break;
273+
274+
// Merge the pair
275+
string merged = symbols[bestIndex] + symbols[bestIndex + 1];
276+
symbols[bestIndex] = merged;
277+
symbols.RemoveAt(bestIndex + 1);
278+
}
279+
return symbols;
280+
}
281+
282+
244283
/// <summary>
245284
/// Parses the coordinate.
246285
/// </summary>
@@ -337,6 +376,22 @@ private MapCollection<long, string> CreateVocabMapping()
337376
}
338377

339378

379+
/// <summary>
380+
/// Creates the merges mapping.
381+
/// </summary>
382+
private Dictionary<MergeToken, int> CreateMergesMapping()
383+
{
384+
var mergesFile = Path.Combine(_configuration.Path, "merges.txt");
385+
return File.ReadLines(mergesFile)
386+
.Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#"))
387+
.Select((line, index) =>
388+
{
389+
var parts = line.Split(' ', StringSplitOptions.RemoveEmptyEntries);
390+
return (new MergeToken(parts[0], parts[1]), index);
391+
}).ToDictionary(x => x.Item1, x => x.index);
392+
}
393+
394+
340395
/// <summary>
341396
/// Creates the coordinate mapping.
342397
/// </summary>
@@ -374,5 +429,8 @@ private record AddedTokenJson
374429
[JsonPropertyName("content")]
375430
public string Content { get; set; }
376431
}
432+
433+
private record MergeToken(string PartA, string PartB);
434+
377435
}
378436
}

TensorStack.TextGeneration/Pipelines/Florence/PostProcessor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ public List<CoordinateResult> ExtractCoordinatePolygons(IReadOnlyCollection<long
133133
var currentLabel = new List<long>();
134134
var currentCoordinates = new List<int>();
135135
var results = new List<CoordinateResult>();
136+
var index = 0;
136137
foreach (var tokenId in tokens)
137138
{
138-
if (!_tokenizer.TryGetCoordinate(tokenId, out var coordinate))
139+
index++;
140+
if (!_tokenizer.TryGetCoordinate(tokenId, out var coordinate) || index == tokens.Count)
139141
{
140142
if (currentCoordinates.Count > 0)
141143
{
@@ -149,6 +151,9 @@ public List<CoordinateResult> ExtractCoordinatePolygons(IReadOnlyCollection<long
149151
coordinates.Add(new Coordinate<int>(position[0], position[1]));
150152
}
151153

154+
if (coordinates.Count == 0)
155+
return new List<CoordinateResult>();
156+
152157
var scaledCoordinates = _coordinateScaler.ScaleUp(coordinates.ToArray(), sourceImage);
153158
var coordinateBox = new CoordinateBox<int>
154159
(

TensorStack.TextGeneration/Pipelines/Florence/PreProcessor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public string ProcessPrompt(FlorenceOptions options)
4343
TaskType.REGION_PROPOSAL => "Locate the region proposals in the image.",
4444
TaskType.CAPTION_TO_PHRASE_GROUNDING => $"Locate the phrases in the caption: {options.Prompt}",
4545
TaskType.REFERRING_EXPRESSION_SEGMENTATION => $"Locate {options.Prompt} in the image with mask",
46-
TaskType.REGION_TO_SEGMENTATION => $"What is the mask of region {regionTokens}",
46+
TaskType.REGION_TO_SEGMENTATION => $"What is the polygon mask of region {regionTokens}",
4747
TaskType.OPEN_VOCABULARY_DETECTION => $"Locate {options.Prompt} in the image.",
4848
TaskType.REGION_TO_CATEGORY => $"What is the region {regionTokens}?",
4949
TaskType.REGION_TO_DESCRIPTION => $"What does the region {regionTokens} describe?",

TensorStack.TextGeneration/Processing/KVCacheEncoderDecoder.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ public void Update(OrtValue[] currentValues, bool useBranchCache)
8484
{
8585
if (i % 4 == 0)
8686
{
87+
88+
// TODO: Allocate entire Maxlength and update the buffer
89+
8790
_values[i].Dispose();
8891
_values[i + 1].Dispose();
8992

TensorStack.TextGeneration/Processing/Logit/BOSLogitsProcessor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public void Process(List<long> inputs, Tensor<float> logits)
2727
{
2828
if (inputs.Count == 0)
2929
{
30+
inputs.Add(_bosTokenId);
3031
logits.Fill(float.NegativeInfinity);
3132
logits[0, 0] = float.NegativeZero;
3233
}

0 commit comments

Comments
 (0)