Skip to content

Commit f1471bb

Browse files
committed
Add florence task descriptions
1 parent fde6e26 commit f1471bb

File tree

12 files changed

+250
-149
lines changed

12 files changed

+250
-149
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ site/
350350
docker-test-output/*
351351

352352
Examples/*
353-
TensorStack.UI.WPF/*
353+
TensorStack.WPF/*
354354
TensorStackFull.sln
355355
TensorStack.Diffusers/*
356356
TensorStudio/*
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System;
4+
using Microsoft.ML.OnnxRuntime;
5+
6+
namespace TensorStack.Common
7+
{
8+
/// <summary>
/// A named execution provider that produces ONNX Runtime <see cref="SessionOptions"/>
/// for a given <see cref="ModelConfig"/> by delegating to a caller-supplied factory.
/// </summary>
public class ExecutionProvider
{
    private readonly string _name;
    private readonly Func<ModelConfig, SessionOptions> _sessionOptionsFactory;

    /// <summary>
    /// Initializes a new instance of the <see cref="ExecutionProvider"/> class.
    /// </summary>
    /// <param name="name">The provider name, exposed unchanged through <see cref="Name"/>.</param>
    /// <param name="sessionOptionsFactory">The factory invoked by <see cref="CreateSession"/> to build the session options.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sessionOptionsFactory"/> is <c>null</c>.</exception>
    public ExecutionProvider(string name, Func<ModelConfig, SessionOptions> sessionOptionsFactory)
    {
        _name = name;

        // Reject a missing factory here; otherwise the failure only surfaces later
        // as a NullReferenceException on the first CreateSession call.
        _sessionOptionsFactory = sessionOptionsFactory ?? throw new ArgumentNullException(nameof(sessionOptionsFactory));
    }

    /// <summary>
    /// Gets the provider name supplied at construction.
    /// </summary>
    public string Name => _name;

    /// <summary>
    /// Builds the session options for the specified model configuration.
    /// Note: despite its name, this returns <see cref="SessionOptions"/>, not a session.
    /// </summary>
    /// <param name="modelConfig">The model configuration passed through to the factory.</param>
    /// <returns>The <see cref="SessionOptions"/> returned by the factory.</returns>
    public SessionOptions CreateSession(ModelConfig modelConfig)
    {
        return _sessionOptionsFactory(modelConfig);
    }
}
26+
}

TensorStack.Common/ModelConfig.cs

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System;
43
using System.Text.Json.Serialization;
5-
using Microsoft.ML.OnnxRuntime;
64

75
namespace TensorStack.Common
86
{
@@ -22,23 +20,4 @@ public virtual void SetProvider(ExecutionProvider executionProvider)
2220
ExecutionProvider = executionProvider;
2321
}
2422
}
25-
26-
public class ExecutionProvider
27-
{
28-
private readonly string _name;
29-
private readonly Func<ModelConfig, SessionOptions> _sessionOptionsFactory;
30-
31-
public ExecutionProvider(string name, Func<ModelConfig, SessionOptions> sessionOptionsFactory)
32-
{
33-
_name = name;
34-
_sessionOptionsFactory = sessionOptionsFactory;
35-
}
36-
37-
public string Name => _name;
38-
39-
public SessionOptions CreateSession(ModelConfig modelConfig)
40-
{
41-
return _sessionOptionsFactory(modelConfig);
42-
}
43-
}
4423
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using TensorStack.Common.Pipeline;
2+
using TensorStack.TextGeneration.Common;
3+
4+
namespace TensorStack.TextGeneration
5+
{
6+
/// <summary>
/// Text generation contract combining two pipeline interfaces: one taking
/// <see cref="GenerateOptions"/> with a single <see cref="GenerateResult"/>, and one taking
/// <see cref="SearchOptions"/> with an array of <see cref="GenerateResult"/>.
/// Declares no members of its own.
/// </summary>
public interface ITextGeneration
    : IPipeline<GenerateResult, GenerateOptions>,
      IPipeline<GenerateResult[], SearchOptions>
{
}
11+
}

TensorStack.TextGeneration/Pipelines/DecoderPipeline.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ public abstract class DecoderPipeline : IDisposable
2929
/// <param name="decoderConfig">The decoder configuration.</param>
3030
public DecoderPipeline(ITokenizer tokenizer, DecoderConfig decoderConfig)
3131
{
32-
_decoderConfig = decoderConfig;
33-
3432
_tokenizer = tokenizer;
33+
_decoderConfig = decoderConfig;
3534
_decoder = new ModelSession(_decoderConfig);
3635
_sequenceComparer = new SequenceComparer(_tokenizer.SpecialTokens, 5);
3736
}
@@ -126,7 +125,7 @@ protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options
126125
logitsProcessor.Process(sequence.Tokens, logits);
127126

128127
// Sample
129-
var sample = sampler.Sample(logits, temperature: options.Temperature).First();
128+
var sample = sampler.Sample(logits, temperature: options.Temperature).First();
130129
sequence.Tokens.Add(sample.TokenId);
131130
sequence.Score += sample.Score;
132131

TensorStack.TextGeneration/Pipelines/EncoderDecoderPipeline.cs

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System;
4-
using System.Collections.Generic;
53
using System.Linq;
6-
using System.Runtime.CompilerServices;
74
using System.Threading;
85
using System.Threading.Tasks;
96
using TensorStack.Common;
10-
using TensorStack.Common.Pipeline;
117
using TensorStack.Common.Tensor;
128
using TensorStack.TextGeneration.Common;
139
using TensorStack.TextGeneration.Processing;
1410

1511
namespace TensorStack.TextGeneration.Pipelines
1612
{
17-
public abstract class EncoderDecoderPipeline : DecoderPipeline,
18-
IPipeline<GenerateResult, GenerateOptions>,
19-
IPipelineStream<GenerateResult, SearchOptions>
13+
public abstract class EncoderDecoderPipeline : DecoderPipeline
2014
{
2115
/// <summary>
2216
/// Initializes a new instance of the <see cref="EncoderDecoderPipeline"/> class.
@@ -57,55 +51,6 @@ public override async Task UnloadAsync(CancellationToken cancellationToken = def
5751
await Encoder.UnloadAsync();
5852
}
5953

60-
/// <summary>
61-
/// Run pipeline GreedySearch
62-
/// </summary>
63-
/// <param name="options">The options.</param>
64-
/// <param name="progressCallback">The progress callback.</param>
65-
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
66-
/// <returns>A Task&lt;GenerateResult&gt; representing the asynchronous operation.</returns>
67-
public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
68-
{
69-
await TokenizePromptAsync(options);
70-
71-
var sequence = await GreedySearchAsync(options, cancellationToken);
72-
using (sequence)
73-
{
74-
return new GenerateResult
75-
{
76-
Score = sequence.Score,
77-
Result = Tokenizer.Decode(sequence.Tokens)
78-
};
79-
}
80-
}
81-
82-
83-
/// <summary>
84-
/// Run pipeline BeamSearch
85-
/// </summary>
86-
/// <param name="options">The options.</param>
87-
/// <param name="progressCallback">The progress callback.</param>
88-
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
89-
/// <returns>A Task&lt;IAsyncEnumerable`1&gt; representing the asynchronous operation.</returns>
90-
public virtual async IAsyncEnumerable<GenerateResult> RunAsync(SearchOptions options, IProgress<RunProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
91-
{
92-
await TokenizePromptAsync(options);
93-
94-
var sequences = await BeamSearchAsync(options, cancellationToken);
95-
foreach (var sequence in sequences)
96-
{
97-
using (sequence)
98-
{
99-
yield return new GenerateResult
100-
{
101-
Beam = sequence.Id,
102-
Score = sequence.Score,
103-
Result = Tokenizer.Decode(sequence.Tokens)
104-
};
105-
}
106-
}
107-
}
108-
10954

11055
/// <summary>
11156
/// Tokenize the prompt

TensorStack.TextGeneration/Pipelines/Florence/FlorenceOptions.cs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System.ComponentModel.DataAnnotations;
34
using TensorStack.Common.Tensor;
45
using TensorStack.Common.Vision;
56
using TensorStack.TextGeneration.Common;
@@ -13,32 +14,62 @@ public record FlorenceOptions : GenerateOptions
1314
public CoordinateBox<float> Region { get; set; }
1415
}
1516

17+
1618
public record FlorenceSearchOptions : FlorenceOptions
1719
{
18-
public FlorenceSearchOptions(){ }
20+
public FlorenceSearchOptions() { }
1921
public FlorenceSearchOptions(FlorenceOptions options) : base(options) { }
2022
}
2123

2224

23-
24-
2525
/// <summary>
/// The predefined Florence tasks. Each member carries a <see cref="DisplayAttribute"/>
/// with a display name and a short description of the task.
/// Numeric values are spelled out and match the declaration order (0..15).
/// </summary>
public enum TaskType
{
    [Display(Name = "None", Description = "Free text prompt without a predefined task.")]
    NONE = 0,

    [Display(Name = "OCR", Description = "Reads all the visible text in an image.")]
    OCR = 1,

    [Display(Name = "OCR with Region", Description = "Reads text in an image and also gives the region where each piece of text is found.")]
    OCR_WITH_REGION = 2,

    [Display(Name = "Caption", Description = "Generates a short description of the overall image.")]
    CAPTION = 3,

    [Display(Name = "Detailed Caption", Description = "Produces a richer description of the image with more detail than a normal caption.")]
    DETAILED_CAPTION = 4,

    [Display(Name = "More Detailed Caption", Description = "Gives a very thorough and verbose description of the image.")]
    MORE_DETAILED_CAPTION = 5,

    [Display(Name = "Object Detection", Description = "Identifies and localizes objects in the image with bounding boxes.")]
    OD = 6,

    [Display(Name = "Dense Region Caption", Description = "Splits the image into many regions and generates a caption for each region.")]
    DENSE_REGION_CAPTION = 7,

    [Display(Name = "Caption to Phrase Grounding", Description = "Finds the specific region(s) in the image that correspond to a given phrase in a caption.")]
    CAPTION_TO_PHRASE_GROUNDING = 8,

    [Display(Name = "Referring Expression Segmentation", Description = "Given a phrase (e.g., 'the red car'), segments out that exact object region at the pixel level.")]
    REFERRING_EXPRESSION_SEGMENTATION = 9,

    [Display(Name = "Region to Segmentation", Description = "Converts a region (bounding box) into a precise pixel-level segmentation mask.")]
    REGION_TO_SEGMENTATION = 10,

    [Display(Name = "Open Vocabulary Detection", Description = "Detects objects of arbitrary categories, even ones not seen during training.")]
    OPEN_VOCABULARY_DETECTION = 11,

    [Display(Name = "Region to Category", Description = "Assigns a category label (e.g., 'cat', 'chair') to a given region.")]
    REGION_TO_CATEGORY = 12,

    [Display(Name = "Region to Description", Description = "Generates a natural-language description for a given region.")]
    REGION_TO_DESCRIPTION = 13,

    [Display(Name = "Region to OCR", Description = "Extracts text only from within a specified region.")]
    REGION_TO_OCR = 14,

    [Display(Name = "Region Proposal", Description = "Suggests candidate regions of interest in the image (without labeling them).")]
    REGION_PROPOSAL = 15
}
4475
}

TensorStack.TextGeneration/Pipelines/Florence/FlorencePipeline.cs

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,26 @@ protected override void Dispose(bool disposing)
267267
/// <param name="visionModel">The vision model.</param>
268268
/// <returns>FlorencePipeline.</returns>
269269
public static FlorencePipeline Create(ExecutionProvider provider, string modelPath, FlorenceType modelType, string encoderModel = "encoder_model.onnx", string decoderModel = "decoder_model_merged.onnx", string embedModel = "embed_tokens.onnx", string visionModel = "vision_encoder.onnx")
{
    // Convenience overload: the same provider is used for the encoder, decoder,
    // embeds and vision models; everything else is forwarded unchanged.
    return Create(
        encoderProvider: provider,
        decoderProvider: provider,
        embedsProvider: provider,
        visionProvider: provider,
        modelPath: modelPath,
        modelType: modelType,
        encoderModel: encoderModel,
        decoderModel: decoderModel,
        embedModel: embedModel,
        visionModel: visionModel);
}
273+
274+
275+
/// <summary>
276+
/// Creates a FlorencePipeline with the specified configuration.
277+
/// </summary>
278+
/// <param name="encoderProvider">The encoder provider.</param>
279+
/// <param name="decoderProvider">The decoder provider.</param>
280+
/// <param name="embedsProvider">The embeds provider.</param>
281+
/// <param name="visionProvider">The vision provider.</param>
282+
/// <param name="modelPath">The model path.</param>
283+
/// <param name="modelType">Type of the model.</param>
284+
/// <param name="encoderModel">The encoder model.</param>
285+
/// <param name="decoderModel">The decoder model.</param>
286+
/// <param name="embedModel">The embed model.</param>
287+
/// <param name="visionModel">The vision model.</param>
288+
/// <returns>FlorencePipeline.</returns>
289+
public static FlorencePipeline Create(ExecutionProvider encoderProvider, ExecutionProvider decoderProvider, ExecutionProvider embedsProvider, ExecutionProvider visionProvider, string modelPath, FlorenceType modelType, string encoderModel = "encoder_model.onnx", string decoderModel = "decoder_model_merged.onnx", string embedModel = "embed_tokens.onnx", string visionModel = "vision_encoder.onnx")
270290
{
271291
var numLayers = 6;
272292
var numHeads = 12;
@@ -316,36 +336,11 @@ public static FlorencePipeline Create(ExecutionProvider provider, string modelPa
316336
}
317337
};
318338

319-
config.EncoderConfig.SetProvider(provider);
320-
//config.DecoderConfig.SetProvider(provider);
321-
config.DecoderConfig.SetProvider(ProviderCPU()); // TODO
322-
config.EmbedsConfig.SetProvider(provider);
323-
config.VisionConfig.SetProvider(provider);
339+
config.EncoderConfig.SetProvider(encoderProvider);
340+
config.DecoderConfig.SetProvider(decoderProvider);
341+
config.EmbedsConfig.SetProvider(embedsProvider);
342+
config.VisionConfig.SetProvider(visionProvider);
324343
return new FlorencePipeline(config);
325344
}
326-
327-
328-
private static ExecutionProvider ProviderCPU()
329-
{
330-
return new ExecutionProvider("CPU", configuration =>
331-
{
332-
var sessionOptions = new SessionOptions
333-
{
334-
ExecutionMode = ExecutionMode.ORT_PARALLEL,
335-
EnableCpuMemArena = true,
336-
EnableMemoryPattern = true,
337-
GraphOptimizationLevel = GraphOptimizationLevel.ORT_DISABLE_ALL
338-
};
339-
340-
sessionOptions.AppendExecutionProvider_CPU();
341-
return sessionOptions;
342-
});
343-
}
344-
}
345-
346-
public enum FlorenceType
347-
{
348-
Base = 0,
349-
Large = 1
350345
}
351346
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
namespace TensorStack.TextGeneration.Pipelines.Florence
4+
{
5+
/// <summary>
/// Identifies the Florence model type; passed as the <c>modelType</c> argument
/// of <c>FlorencePipeline.Create</c>.
/// </summary>
public enum FlorenceType
{
    /// <summary>The base model type (numeric value 0).</summary>
    Base = 0,

    /// <summary>The large model type (numeric value 1).</summary>
    Large = 1
}
10+
}

0 commit comments

Comments
 (0)