@@ -89,8 +89,6 @@ import {
8989 MinNewTokensLengthLogitsProcessor ,
9090
9191 TemperatureLogitsWarper ,
92- TopKLogitsWarper ,
93- TopPLogitsWarper ,
9492 ClassifierFreeGuidanceLogitsProcessor ,
9593} from './generation/logits_process.js' ;
9694
@@ -1310,32 +1308,6 @@ export class PreTrainedModel extends Callable {
13101308 return this . configs ?. generation_config ?? null ;
13111309 }
13121310
1313- /**
1314- * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
1315- * instances used for multinomial sampling.
1316- * @param {GenerationConfig } generation_config The generation config.
1317- * @returns {LogitsProcessorList } generation_config
1318- */
1319- _get_logits_warper ( generation_config ) {
1320-
1321- // instantiate warpers list
1322- const warpers = new LogitsProcessorList ( ) ;
1323-
1324- if ( generation_config . temperature !== null && generation_config . temperature !== 1.0 ) {
1325- warpers . push ( new TemperatureLogitsWarper ( generation_config . temperature ) ) ;
1326- }
1327- if ( generation_config . top_k !== null && generation_config . top_k !== 0 ) {
1328- // TODO: add min_tokens_to_keep
1329- warpers . push ( new TopKLogitsWarper ( generation_config . top_k ) ) ;
1330- }
1331- if ( generation_config . top_p !== null && generation_config . top_p < 1.0 ) {
1332- // TODO: add min_tokens_to_keep
1333- warpers . push ( new TopPLogitsWarper ( generation_config . top_p ) ) ;
1334- }
1335-
1336- return warpers ;
1337- }
1338-
13391311 /**
13401312 * @param {GenerationConfig } generation_config
13411313 * @param {number } input_ids_seq_length The starting sequence length for the input ids.
@@ -1455,6 +1427,19 @@ export class PreTrainedModel extends Callable {
14551427 processors . push ( new ClassifierFreeGuidanceLogitsProcessor ( generation_config . guidance_scale ) ) ;
14561428 }
14571429
1430+ if ( generation_config . do_sample ) {
1431+ if ( generation_config . temperature !== null && generation_config . temperature !== 1.0 ) {
1432+ processors . push ( new TemperatureLogitsWarper ( generation_config . temperature ) ) ;
1433+ }
1434+ // TODO: Add TopPLogitsWarper and TopKLogitsWarper
1435+ // if (generation_config.top_k !== null && generation_config.top_k !== 0) {
1436+ // processors.push(new TopKLogitsWarper(generation_config.top_k));
1437+ // }
1438+ // if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
1439+ // processors.push(new TopPLogitsWarper(generation_config.top_p));
1440+ // }
1441+ }
1442+
14581443 if ( logits_processor !== null ) {
14591444 processors . extend ( logits_processor )
14601445 }
0 commit comments