1111using TensorStack . TextGeneration . Common ;
1212using TensorStack . TextGeneration . Pipelines . Phi ;
1313using TensorStack . TextGeneration . Processing ;
14+ using TensorStack . TextGeneration . Processing . Sampler ;
1415using TensorStack . TextGeneration . Tokenizers ;
1516
1617namespace TensorStack . TextGeneration . Pipelines
@@ -100,6 +101,18 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
100101 }
101102
102103
104+ /// <summary>
105+ /// Gets the sampler.
106+ /// </summary>
107+ /// <param name="options">The options.</param>
108+ protected virtual Sampler GetSampler ( GenerateOptions options , bool isBeamSerach )
109+ {
110+ return isBeamSerach
111+ ? new MultinomialSampler ( options )
112+ : new GreedySampler ( options ) ;
113+ }
114+
115+
103116 /// <summary>
104117 /// Greedy search
105118 /// </summary>
@@ -108,7 +121,7 @@ protected virtual async Task TokenizePromptAsync(GenerateOptions options)
108121 /// <returns>A Task<Sequence> representing the asynchronous operation.</returns>
109122 protected virtual async Task < Sequence > GreedySearchAsync ( GenerateOptions options , CancellationToken cancellationToken = default )
110123 {
111- var sampler = new DefaultSampler ( options . Seed ) ;
124+ var sampler = GetSampler ( options , false ) ;
112125 var logitsProcessors = GetLogitsProcessor ( options ) ;
113126 var tokenProcessors = GetTokenProcessors ( options ) ;
114127
@@ -147,7 +160,7 @@ protected virtual async Task<Sequence> GreedySearchAsync(GenerateOptions options
147160 /// <returns>A Task<Sequence[]> representing the asynchronous operation.</returns>
148161 protected virtual async Task < Sequence [ ] > BeamSearchAsync ( GenerateOptions options , CancellationToken cancellationToken = default )
149162 {
150- var sampler = new DefaultSampler ( options . Seed ) ;
163+ var sampler = GetSampler ( options , true ) ;
151164 var logitsProcessors = GetLogitsProcessor ( options ) ;
152165 var tokenProcessors = GetTokenProcessors ( options ) ;
153166
0 commit comments