1616
1717namespace TensorStack . TextGeneration . Pipelines
1818{
19- public abstract class DecoderPipeline : IDisposable
19+ public abstract class DecoderPipeline < O > : IDisposable where O : GenerateOptions
2020 {
2121 private readonly DecoderConfig _decoderConfig ;
2222 private readonly ModelSession _decoder ;
@@ -40,7 +40,7 @@ public DecoderPipeline(ITokenizer tokenizer, DecoderConfig decoderConfig)
4040 protected ModelSession Decoder => _decoder ;
4141 protected DecoderConfig DecoderConfig => _decoderConfig ;
4242 protected TokenizerResult TokenizerOutput { get ; set ; }
43- protected abstract Task < Sequence > InitializeAsync ( GenerateOptions options ) ;
43+ protected abstract Task < Sequence > InitializeAsync ( O options ) ;
4444 protected abstract Task < Tensor < float > > RunDecoderAsync ( Sequence sequence ) ;
4545
4646
@@ -66,7 +66,7 @@ public virtual async Task UnloadAsync(CancellationToken cancellationToken = defa
6666 /// Gets the logits processors.
6767 /// </summary>
6868 /// <param name="options">The options.</param>
69- protected virtual ILogitsProcessor [ ] GetLogitsProcessor ( GenerateOptions options )
69+ protected virtual ILogitsProcessor [ ] GetLogitsProcessor ( O options )
7070 {
7171 return
7272 [
@@ -80,7 +80,7 @@ protected virtual ILogitsProcessor[] GetLogitsProcessor(GenerateOptions options)
8080 /// Gets the token processors.
8181 /// </summary>
8282 /// <param name="options">The options.</param>
83- protected virtual ITokenProcessor [ ] GetTokenProcessors ( GenerateOptions options )
83+ protected virtual ITokenProcessor [ ] GetTokenProcessors ( O options )
8484 {
8585 return
8686 [
@@ -95,7 +95,7 @@ protected virtual ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
9595 /// </summary>
9696 /// <param name="options">The options.</param>
9797 /// <returns>A Task representing the asynchronous operation.</returns>
98- protected virtual async Task TokenizePromptAsync ( GenerateOptions options )
98+ protected virtual async Task TokenizePromptAsync ( O options )
9999 {
100100 TokenizerOutput = await Tokenizer . EncodeAsync ( options . Prompt ) ;
101101 }
@@ -105,7 +105,7 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
105105 /// Gets the sampler.
106106 /// </summary>
107107 /// <param name="options">The options.</param>
108- protected virtual Sampler GetSampler ( GenerateOptions options , bool isBeamSerach )
108+ protected virtual Sampler GetSampler ( O options , bool isBeamSerach )
109109 {
110110 return isBeamSerach
111111 ? new MultinomialSampler ( options )
@@ -119,7 +119,7 @@ protected virtual Sampler GetSampler(GenerateOptions options, bool isBeamSerach)
119119 /// <param name="options">The options.</param>
120120 /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
121121 /// <returns>A Task<Sequence> representing the asynchronous operation.</returns>
122- protected virtual async Task < Sequence > GreedySearchAsync ( GenerateOptions options , CancellationToken cancellationToken = default )
122+ protected virtual async Task < Sequence > GreedySearchAsync ( O options , CancellationToken cancellationToken = default )
123123 {
124124 var sampler = GetSampler ( options , false ) ;
125125 var logitsProcessors = GetLogitsProcessor ( options ) ;
@@ -158,8 +158,8 @@ protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options
158158 /// <param name="options">The options.</param>
159159 /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
160160 /// <returns>A Task<Sequence[]> representing the asynchronous operation.</returns>
161- protected virtual async Task < Sequence [ ] > BeamSearchAsync ( GenerateOptions options , CancellationToken cancellationToken = default )
162- {
161+ protected virtual async Task < Sequence [ ] > BeamSearchAsync ( O options , CancellationToken cancellationToken = default )
162+ {
163163 var sampler = GetSampler ( options , true ) ;
164164 var logitsProcessors = GetLogitsProcessor ( options ) ;
165165 var tokenProcessors = GetTokenProcessors ( options ) ;
@@ -227,7 +227,7 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options
227227 // Process Beams
228228 foreach ( var beam in activeBeams )
229229 {
230- // Console.WriteLine(Tokenizer.Decode(beam.Tokens));
230+ Console . WriteLine ( Tokenizer . Decode ( beam . Tokens ) ) ;
231231 if ( beam . IsComplete )
232232 continue ;
233233
@@ -256,7 +256,7 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options
256256 /// </summary>
257257 /// <param name="candidates">The sequences.</param>
258258 /// <param name="options">The options.</param>
259- protected virtual IEnumerable < Sequence > GetSequenceCandidates ( SequenceCollection candidates , GenerateOptions options )
259+ protected virtual IEnumerable < Sequence > GetSequenceCandidates ( SequenceCollection candidates , O options )
260260 {
261261 // TODO: Diversity Penalty
262262 _sequenceComparer . SetLength ( options . DiversityLength ) ;
@@ -281,7 +281,7 @@ protected virtual IEnumerable<Sequence> GetSequenceCandidates(SequenceCollection
281281 /// </summary>
282282 /// <param name="sequences">The sequences.</param>
283283 /// <param name="options">The options.</param>
284- protected virtual bool IsEarlyStopping ( SequenceCollection sequences , GenerateOptions options )
284+ protected virtual bool IsEarlyStopping ( SequenceCollection sequences , O options )
285285 {
286286 if ( options . EarlyStopping != EarlyStopping . None )
287287 {
@@ -323,7 +323,7 @@ protected virtual float GetLengthPenalty(Sequence sequence, float penalty)
323323 /// </summary>
324324 /// <param name="sequences">The sequences.</param>
325325 /// <param name="options">The options.</param>
326- protected virtual Sequence [ ] NormalizeAndSort ( SequenceCollection sequences , GenerateOptions options )
326+ protected virtual Sequence [ ] NormalizeAndSort ( SequenceCollection sequences , O options )
327327 {
328328 var resultSequences = sequences
329329 . Where ( x => x . IsComplete )
0 commit comments