Skip to content

Commit a1222f4

Browse files
committed
Abstract sampler and BPETokeizer
1 parent f1471bb commit a1222f4

File tree

6 files changed

+508
-466
lines changed

6 files changed

+508
-466
lines changed

TensorStack.TextGeneration/Pipelines/DecoderPipeline.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using TensorStack.TextGeneration.Common;
1212
using TensorStack.TextGeneration.Pipelines.Phi;
1313
using TensorStack.TextGeneration.Processing;
14+
using TensorStack.TextGeneration.Processing.Sampler;
1415
using TensorStack.TextGeneration.Tokenizers;
1516

1617
namespace TensorStack.TextGeneration.Pipelines
@@ -100,6 +101,18 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
100101
}
101102

102103

104+
/// <summary>
105+
/// Gets the sampler.
106+
/// </summary>
107+
/// <param name="options">The options.</param>
108+
protected virtual Sampler GetSampler(GenerateOptions options, bool isBeamSerach)
109+
{
110+
return isBeamSerach
111+
? new MultinomialSampler(options)
112+
: new GreedySampler(options);
113+
}
114+
115+
103116
/// <summary>
104117
/// Greedy search
105118
/// </summary>
@@ -108,7 +121,7 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
108121
/// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
109122
protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options, CancellationToken cancellationToken = default)
110123
{
111-
var sampler = new DefaultSampler(options.Seed);
124+
var sampler = GetSampler(options, false);
112125
var logitsProcessors = GetLogitsProcessor(options);
113126
var tokenProcessors = GetTokenProcessors(options);
114127

@@ -147,7 +160,7 @@ protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options
147160
/// <returns>A Task&lt;Sequence[]&gt; representing the asynchronous operation.</returns>
148161
protected virtual async Task<Sequence[]> BeamSearchAsync(GenerateOptions options, CancellationToken cancellationToken = default)
149162
{
150-
var sampler = new DefaultSampler(options.Seed);
163+
var sampler = GetSampler(options, true);
151164
var logitsProcessors = GetLogitsProcessor(options);
152165
var tokenProcessors = GetTokenProcessors(options);
153166

0 commit comments

Comments
 (0)