1 file changed: +5 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -264,6 +264,9 @@ function getNormalizedConfig(config) {
264264 */
265265export function getCacheShapes ( config , options ) {
266266 if ( config . model_type === 'lfm2' ) {
267+ const pkv_prefix = options ?. prefix ?? 'past_key_values' ;
268+ const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past' ;
269+
267270 // Custom caching mechanism for LFM2
268271 /** @type {Record<string, number[]> } */
269272 const cache_values = { } ;
@@ -274,10 +277,10 @@ export function getCacheShapes(config, options) {
274277 for ( let i = 0 ; i < layer_types . length ; ++ i ) {
275278 if ( layer_types [ i ] === 'full_attention' ) {
276279 for ( const kv of [ 'key' , 'value' ] ) {
277- cache_values [ `past_key_values .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
280+ cache_values [ `${ pkv_prefix } .${ i } .${ kv } ` ] = [ batch_size , num_key_value_heads , 0 , head_dim ] ;
278281 }
279282 } else if ( layer_types [ i ] === 'conv' ) {
280- cache_values [ `past_conv .${ i } ` ] = [ batch_size , hidden_size , conv_L_cache ] ;
283+ cache_values [ `${ conv_prefix } _conv .${ i } ` ] = [ batch_size , hidden_size , conv_L_cache ] ;
281284 } else {
282285 throw new Error ( `Unsupported layer type: ${ layer_types [ i ] } ` ) ;
283286 }
You can’t perform that action at this time.
0 commit comments