@@ -64,14 +64,14 @@ public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<Ru
6464
6565 var sequences = await BeamSearchAsync ( options , cancellationToken ) ;
6666 var results = new GenerateResult [ sequences . Length ] ;
67- for ( int i = 0 ; i < sequences . Length ; i ++ )
67+ for ( int beam = 0 ; beam < sequences . Length ; beam ++ )
6868 {
69- var sequence = sequences [ i ] ;
69+ var sequence = sequences [ beam ] ;
7070 using ( sequence )
7171 {
72- results [ i ] = new GenerateResult
72+ results [ beam ] = new GenerateResult
7373 {
74- Beam = sequence . Id ,
74+ Beam = beam ,
7575 Score = sequence . Score ,
7676 PenaltyScore = sequence . PenaltyScore ,
7777 Result = Tokenizer . Decode ( sequence . Tokens )
@@ -82,6 +82,28 @@ public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<Ru
8282 }
8383
8484
85+ /// <summary>
86+ /// Builds the token processors used for generation: one EOSTokenProcessor and one MaxLengthTokenProcessor. NOTE(review): the three literal ids are presumably model-specific stop tokens - confirm against the tokenizer vocabulary.
87+ /// </summary>
88+ /// <param name="options">The generate options; only MinLength and MaxLength are read here.</param>
89+ /// <returns>A two-element array: the EOSTokenProcessor followed by the MaxLengthTokenProcessor.</returns>
90+ protected override ITokenProcessor [ ] GetTokenProcessors ( GenerateOptions options )
91+ {
92+ return
93+ [
94+ new EOSTokenProcessor
95+ (
96+ options . MinLength , // min length
97+ Tokenizer . EOS , // EOS value supplied by the tokenizer
98+ 32000 , // <|endoftext|>
99+ 32001 , // <|assistant|>
100+ 32007 // <|end|>
101+ ) ,
102+ new MaxLengthTokenProcessor ( options . MaxLength ) // configured from options.MaxLength
103+ ] ;
104+ }
105+
106+
85107 /// <summary>
86108 /// Initialize the Decoder cache
87109 /// </summary>
@@ -91,32 +113,15 @@ protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
91113 {
92114 var modelMetadata = await Decoder . LoadAsync ( ) ;
93115 var dataType = modelMetadata . Outputs [ 0 ] . Value . ElementDataType ;
94- var kvCache = new KVCacheDecoder ( dataType , DecoderConfig . NumHeads , DecoderConfig . NumLayers , DecoderConfig . HiddenSize , DecoderConfig . NumKVHeads ) ;
116+ var kvCache = new KVCacheDecoder ( dataType , DecoderConfig . NumHeads , DecoderConfig . NumLayers , DecoderConfig . HiddenSize , DecoderConfig . NumKVHeads , options . MaxLength ) ;
95117 var sequence = new Sequence ( kvCache , Tokenizer . BOS ) ;
96118 sequence . Initialize ( TokenizerOutput . Length ) ;
97119
98- var positionIds = GetPositionIds ( modelMetadata , 0 , TokenizerOutput . Length ) ;
99- var attentionMask = new Tensor < long > ( [ 1 , TokenizerOutput . Length ] , 1 ) ;
100- using ( var parameters = new ModelParameters ( modelMetadata ) )
101- {
102- // Inputs
103- parameters . AddInput ( TokenizerOutput . InputIds ) ;
104- if ( positionIds != null )
105- parameters . AddInput ( positionIds ) ;
106- parameters . AddInput ( attentionMask ) ;
107- foreach ( var pastKeyValue in sequence . Cache )
108- parameters . AddInput ( pastKeyValue ) ;
109-
110- // Outputs
111- foreach ( var output in modelMetadata . Outputs )
112- parameters . AddOutput ( ) ;
113-
114- // Result
115- var modelResult = Decoder . RunInference ( parameters ) ;
116- modelResult [ 0 ] . Dispose ( ) ; // logits
117- var presentKeyValues = modelResult . ToArray ( ) [ 1 ..] ;
118- sequence . UpdateCache ( presentKeyValues , false ) ;
119- }
120+ var position = TokenizerOutput . Length ;
121+ var inputIds = TokenizerOutput . InputIds ;
122+ var positionIds = GetPositionIds ( modelMetadata , 0 , position ) ;
123+ var attentionMask = new Tensor < long > ( [ 1 , position ] , 1 ) ;
124+ RunDecoderInternalAsync ( modelMetadata , sequence , inputIds , positionIds , attentionMask , false ) ;
120125 return sequence ;
121126 }
122127
@@ -128,11 +133,26 @@ protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
128133 /// <returns>A task whose result is the logits tensor returned by the decoder run.</returns>
129134 protected override async Task < Tensor < float > > RunDecoderAsync ( Sequence sequence )
130135 {
131- var currentPosition = TokenizerOutput . Length + sequence . Tokens . Count ;
132136 var modelMetadata = await Decoder . LoadAsync ( ) ;
137+ var position = TokenizerOutput . Length + sequence . Tokens . Count ; // tokenizer output length plus the tokens already in the sequence
133138 var inputIds = new Tensor < long > ( [ 1 , 1 ] , sequence . Tokens [ ^ 1 ] ) ; // [1, 1] input built from the last token of the sequence only
134- var positionIds = GetPositionIds ( modelMetadata , currentPosition ) ;
135- var attentionMask = new Tensor < long > ( [ 1 , currentPosition ] , 1 ) ;
139+ var positionIds = GetPositionIds ( modelMetadata , position ) ;
140+ var attentionMask = new Tensor < long > ( [ 1 , position ] , 1 ) ; // [1, position] mask; presumably filled with 1 - confirm Tensor constructor semantics
141+ return RunDecoderInternalAsync ( modelMetadata , sequence , inputIds , positionIds , attentionMask , true ) ; // last argument is useBranchCache = true
142+ }
143+
144+
145+ /// <summary>
146+ /// Runs one decoder inference with the supplied inputs, updates the sequence cache from the extra outputs and returns the logits tensor. NOTE(review): this method is synchronous despite the Async suffix.
147+ /// </summary>
148+ /// <param name="modelMetadata">The model metadata used to create the ModelParameters.</param>
149+ /// <param name="sequence">The sequence whose cache is updated after inference.</param>
150+ /// <param name="inputIds">The input ids.</param>
151+ /// <param name="positionIds">The position ids.</param>
152+ /// <param name="attentionMask">The attention mask.</param>
153+ /// <param name="useBranchCache">Forwarded unchanged as the second argument of Sequence.UpdateCache.</param>
154+ private Tensor < float > RunDecoderInternalAsync ( ModelMetadata modelMetadata , Sequence sequence , Tensor < long > inputIds , Tensor < long > positionIds , Tensor < long > attentionMask , bool useBranchCache )
155+ {
136156 using ( var parameters = new ModelParameters ( modelMetadata ) )
137157 {
138158 // Inputs
@@ -151,38 +171,17 @@ protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
151171 var modelResult = Decoder . RunInference ( parameters ) ;
152172 using ( var logitsResult = modelResult [ 0 ] ) // output 0 is treated as the logits and disposed when this block exits
153173 {
154- var logits = logitsResult . ToTensor ( ) ;
174+ var dimension = logitsResult . GetDimensions ( ) ;
175+ var logits = logitsResult . ToTensor ( dimension [ 1 ..] ) ; // build the tensor using the dimensions without their first entry
155176 var presentKeyValues = modelResult . ToArray ( ) [ 1 ..] ; // every output after the first is handed to the cache
156177
157- sequence . UpdateCache ( presentKeyValues , false ) ;
158- return logits . Reshape ( [ logits . Dimensions [ 0 ] , logits . Dimensions [ 2 ] ] ) ;
178+ sequence . UpdateCache ( presentKeyValues , useBranchCache ) ;
179+ return logits ; // NOTE(review): returned after logitsResult is disposed - confirm ToTensor copies the data
159180 }
160181 }
161182 }
162183
163184
164- /// <summary>
165- /// Gets the token processors.
166- /// </summary>
167- /// <param name="options">The options.</param>
168- /// <returns>ITokenProcessor[].</returns>
169- protected override ITokenProcessor [ ] GetTokenProcessors ( GenerateOptions options )
170- {
171- return
172- [
173- new EOSTokenProcessor
174- (
175- options . MinLength , // min length
176- Tokenizer . EOS ,
177- 32000 , // <|endoftext|>
178- 32001 // <|assistant|>
179- // 32007 // <|end|>
180- ) ,
181- new MaxLengthTokenProcessor ( options . MaxLength )
182- ] ;
183- }
184-
185-
186185 /// <summary>
187186 /// Creates the Phi3Pipeline
188187 /// </summary>
0 commit comments