Skip to content

Commit 3e9d70d

Browse files
committed
working Whisper pipeline
1 parent 80ed0c8 commit 3e9d70d

File tree

15 files changed

+738
-237
lines changed

15 files changed

+738
-237
lines changed

TensorStack.Common/Common/FileHelper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public static bool DeleteFile(string filename)
99
{
1010
try
1111
{
12-
if (File.Exists(filename))
12+
if (!File.Exists(filename))
1313
return false;
1414

1515
File.Delete(filename);

TensorStack.Common/Extensions/OrtExtensions.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ public static Tensor<float> ToTensor(this OrtValue ortValue)
125125
}
126126

127127

128+
/// <summary>
129+
/// Copy OrtValue data to float Tensor.
130+
/// </summary>
131+
/// <param name="ortValue">The ort value.</param>
132+
/// <param name="dimensions">The dimensions.</param>
133+
/// <returns>Tensor&lt;System.Single&gt;.</returns>
134+
public static Tensor<float> ToTensor(this OrtValue ortValue, int[] dimensions)
135+
{
136+
return CreateTensor<float>(ortValue, dimensions);
137+
}
138+
139+
128140
/// <summary>
129141
/// Copy OrtValue data to a Tensor of type T.
130142
/// </summary>
@@ -137,6 +149,19 @@ public static Tensor<T> ToTensor<T>(this OrtValue ortValue) where T : unmanaged,
137149
}
138150

139151

152+
/// <summary>
153+
/// Copy OrtValue data to a Tensor of type T, using the supplied dimensions.
154+
/// </summary>
155+
/// <typeparam name="T"></typeparam>
156+
/// <param name="ortValue">The ort value.</param>
157+
/// <param name="dimensions">The dimensions.</param>
158+
/// <returns>Tensor&lt;T&gt;.</returns>
159+
public static Tensor<T> ToTensor<T>(this OrtValue ortValue, int[] dimensions) where T : unmanaged, INumber<T>
160+
{
161+
return CreateTensor<T>(ortValue, dimensions);
162+
}
163+
164+
140165
/// <summary>
141166
/// Copy OrtValue data to float array.
142167
/// </summary>
@@ -160,6 +185,17 @@ public static T[] ToArray<T>(this OrtValue ortValue) where T : unmanaged, INumbe
160185
}
161186

162187

188+
/// <summary>
189+
/// Gets the dimensions.
190+
/// </summary>
191+
/// <param name="ortValue">The ort value.</param>
192+
/// <returns>System.Int32[].</returns>
193+
public static int[] GetDimensions(this OrtValue ortValue)
194+
{
195+
return ortValue.GetTensorTypeAndShape().Shape.ToInt();
196+
}
197+
198+
163199
/// <summary>
164200
/// Creates the Tensor from OrtValue.
165201
/// </summary>
@@ -169,6 +205,19 @@ private static Tensor<T> CreateTensor<T>(OrtValue ortValue) where T : unmanaged,
169205
{
170206
var metadata = ortValue.GetTensorTypeAndShape();
171207
var dimensions = metadata.Shape.ToInt();
208+
return CreateTensor<T>(ortValue, dimensions);
209+
}
210+
211+
212+
/// <summary>
213+
/// Creates the Tensor from OrtValue.
214+
/// </summary>
215+
/// <typeparam name="T"></typeparam>
216+
/// <param name="ortValue">The ort value.</param>
217+
/// <param name="dimensions">The dimensions.</param>
218+
/// <returns>Tensor&lt;T&gt;.</returns>
219+
private static Tensor<T> CreateTensor<T>(OrtValue ortValue, int[] dimensions) where T : unmanaged, INumber<T>
220+
{
172221
var buffer = CreateArray<T>(ortValue);
173222
return new Tensor<T>(buffer, dimensions);
174223
}

TensorStack.TextGeneration/Pipelines/DecoderPipeline.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
namespace TensorStack.TextGeneration.Pipelines
1818
{
19-
public abstract class DecoderPipeline : IDisposable
19+
public abstract class DecoderPipeline<O> : IDisposable where O : GenerateOptions
2020
{
2121
private readonly DecoderConfig _decoderConfig;
2222
private readonly ModelSession _decoder;
@@ -40,7 +40,7 @@ public DecoderPipeline(ITokenizer tokenizer, DecoderConfig decoderConfig)
4040
protected ModelSession Decoder => _decoder;
4141
protected DecoderConfig DecoderConfig => _decoderConfig;
4242
protected TokenizerResult TokenizerOutput { get; set; }
43-
protected abstract Task<Sequence> InitializeAsync(GenerateOptions options);
43+
protected abstract Task<Sequence> InitializeAsync(O options);
4444
protected abstract Task<Tensor<float>> RunDecoderAsync(Sequence sequence);
4545

4646

@@ -66,7 +66,7 @@ public virtual async Task UnloadAsync(CancellationToken cancellationToken = defa
6666
/// Gets the logits processors.
6767
/// </summary>
6868
/// <param name="options">The options.</param>
69-
protected virtual ILogitsProcessor[] GetLogitsProcessor(GenerateOptions options)
69+
protected virtual ILogitsProcessor[] GetLogitsProcessor(O options)
7070
{
7171
return
7272
[
@@ -80,7 +80,7 @@ protected virtual ILogitsProcessor[] GetLogitsProcessor(GenerateOptions options)
8080
/// Gets the token processors.
8181
/// </summary>
8282
/// <param name="options">The options.</param>
83-
protected virtual ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
83+
protected virtual ITokenProcessor[] GetTokenProcessors(O options)
8484
{
8585
return
8686
[
@@ -95,7 +95,7 @@ protected virtual ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
9595
/// </summary>
9696
/// <param name="options">The options.</param>
9797
/// <returns>A Task representing the asynchronous operation.</returns>
98-
protected virtual async Task TokenizePromptAsync(GenerateOptions options)
98+
protected virtual async Task TokenizePromptAsync(O options)
9999
{
100100
TokenizerOutput = await Tokenizer.EncodeAsync(options.Prompt);
101101
}
@@ -105,7 +105,7 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
105105
/// Gets the sampler.
106106
/// </summary>
107107
/// <param name="options">The options.</param>
108-
protected virtual Sampler GetSampler(GenerateOptions options, bool isBeamSerach)
108+
protected virtual Sampler GetSampler(O options, bool isBeamSerach)
109109
{
110110
return isBeamSerach
111111
? new MultinomialSampler(options)
@@ -119,7 +119,7 @@ protected virtual Sampler GetSampler(GenerateOptions options, bool isBeamSerach)
119119
/// <param name="options">The options.</param>
120120
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
121121
/// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
122-
protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options, CancellationToken cancellationToken = default)
122+
protected virtual async Task<Sequence> GreedySearchAsync(O options, CancellationToken cancellationToken = default)
123123
{
124124
var sampler = GetSampler(options, false);
125125
var logitsProcessors = GetLogitsProcessor(options);
@@ -158,8 +158,8 @@ protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options
158158
/// <param name="options">The options.</param>
159159
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
160160
/// <returns>A Task&lt;Sequence[]&gt; representing the asynchronous operation.</returns>
161-
protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options, CancellationToken cancellationToken = default)
162-
{
161+
protected virtual async Task<Sequence[]> BeamSearchAsync(O options, CancellationToken cancellationToken = default)
162+
{
163163
var sampler = GetSampler(options, true);
164164
var logitsProcessors = GetLogitsProcessor(options);
165165
var tokenProcessors = GetTokenProcessors(options);
@@ -227,7 +227,7 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options
227227
// Process Beams
228228
foreach (var beam in activeBeams)
229229
{
230-
// Console.WriteLine(Tokenizer.Decode(beam.Tokens));
230+
Console.WriteLine(Tokenizer.Decode(beam.Tokens));
231231
if (beam.IsComplete)
232232
continue;
233233

@@ -256,7 +256,7 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options
256256
/// </summary>
257257
/// <param name="candidates">The sequences.</param>
258258
/// <param name="options">The options.</param>
259-
protected virtual IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection candidates, GenerateOptions options)
259+
protected virtual IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection candidates, O options)
260260
{
261261
// TODO: Diversity Penalty
262262
_sequenceComparer.SetLength(options.DiversityLength);
@@ -281,7 +281,7 @@ protected virtual IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection
281281
/// </summary>
282282
/// <param name="sequences">The sequences.</param>
283283
/// <param name="options">The options.</param>
284-
protected virtual bool IsEarlyStopping(SequenceCollection sequences, GenerateOptions options)
284+
protected virtual bool IsEarlyStopping(SequenceCollection sequences, O options)
285285
{
286286
if (options.EarlyStopping != EarlyStopping.None)
287287
{
@@ -323,7 +323,7 @@ protected virtual float GetLengthPenalty(Sequence sequence, float penalty)
323323
/// </summary>
324324
/// <param name="sequences">The sequences.</param>
325325
/// <param name="options">The options.</param>
326-
protected virtual Sequence[] NormalizeAndSort(SequenceCollection sequences, GenerateOptions options)
326+
protected virtual Sequence[] NormalizeAndSort(SequenceCollection sequences, O options)
327327
{
328328
var resultSequences = sequences
329329
.Where(x => x.IsComplete)

TensorStack.TextGeneration/Pipelines/EncoderDecoderPipeline.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace TensorStack.TextGeneration.Pipelines
1212
{
13-
public abstract class EncoderDecoderPipeline : DecoderPipeline
13+
public abstract class EncoderDecoderPipeline<O> : DecoderPipeline<O> where O : GenerateOptions
1414
{
1515
/// <summary>
1616
/// Initializes a new instance of the <see cref="EncoderDecoderPipeline{O}"/> class.
@@ -57,7 +57,7 @@ public override async Task UnloadAsync(CancellationToken cancellationToken = def
5757
/// </summary>
5858
/// <param name="options">The options.</param>
5959
/// <returns>A Task representing the asynchronous operation.</returns>
60-
protected override async Task TokenizePromptAsync(GenerateOptions options)
60+
protected override async Task TokenizePromptAsync(O options)
6161
{
6262
await base.TokenizePromptAsync(options);
6363
EncoderOutput = await RunEncoderAsync();
@@ -128,7 +128,7 @@ protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
128128
/// </summary>
129129
/// <param name="options">The options.</param>
130130
/// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
131-
protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
131+
protected override async Task<Sequence> InitializeAsync(O options)
132132
{
133133
var modelMetadata = await Decoder.LoadAsync();
134134
var dataType = modelMetadata.Outputs[0].Value.ElementDataType;

TensorStack.TextGeneration/Pipelines/Florence/FlorencePipeline.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
namespace TensorStack.TextGeneration.Pipelines.Florence
1919
{
20-
public class FlorencePipeline : EncoderDecoderPipeline,
20+
public class FlorencePipeline : EncoderDecoderPipeline<FlorenceOptions>,
2121
IPipeline<GenerateResult, FlorenceOptions>,
2222
IPipelineStream<GenerateResult, FlorenceSearchOptions>
2323
{

TensorStack.TextGeneration/Pipelines/Other/SummaryPipeline.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace TensorStack.TextGeneration.Pipelines.Other
1313
{
14-
public class SummaryPipeline : EncoderDecoderPipeline, ITextGeneration
14+
public class SummaryPipeline : EncoderDecoderPipeline<GenerateOptions>, ITextGeneration
1515
{
1616
/// <summary>
1717
/// Initializes a new instance of the <see cref="SummaryPipeline"/> class.

TensorStack.TextGeneration/Pipelines/Phi/Phi3Pipeline.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
namespace TensorStack.TextGeneration.Pipelines.Phi
1717
{
18-
public class Phi3Pipeline : DecoderPipeline, ITextGeneration
18+
public class Phi3Pipeline : DecoderPipeline<GenerateOptions>, ITextGeneration
1919
{
2020
/// <summary>
2121
/// Initializes a new instance of the <see cref="Phi3Pipeline"/> class.
@@ -113,12 +113,9 @@ protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
113113

114114
// Result
115115
var modelResult = Decoder.RunInference(parameters);
116-
using (var logitsResult = modelResult[0])
117-
{
118-
var logits = logitsResult.ToTensor();
119-
var presentKeyValues = modelResult.ToArray()[1..];
120-
sequence.UpdateCache(presentKeyValues, false);
121-
}
116+
modelResult[0].Dispose(); // logits
117+
var presentKeyValues = modelResult.ToArray()[1..];
118+
sequence.UpdateCache(presentKeyValues, false);
122119
}
123120
return sequence;
124121
}

0 commit comments

Comments
 (0)