Skip to content

Commit 80197b2

Browse files
authored
Correctly assign logits warpers in _get_logits_processor (#1422)
1 parent 699dcb5 commit 80197b2

File tree

1 file changed

+13
-28
lines changed

1 file changed

+13
-28
lines changed

src/models.js

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ import {
8989
MinNewTokensLengthLogitsProcessor,
9090

9191
TemperatureLogitsWarper,
92-
TopKLogitsWarper,
93-
TopPLogitsWarper,
9492
ClassifierFreeGuidanceLogitsProcessor,
9593
} from './generation/logits_process.js';
9694

@@ -1310,32 +1308,6 @@ export class PreTrainedModel extends Callable {
13101308
return this.configs?.generation_config ?? null;
13111309
}
13121310

1313-
/**
1314-
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
1315-
* instances used for multinomial sampling.
1316-
* @param {GenerationConfig} generation_config The generation config.
1317-
* @returns {LogitsProcessorList} generation_config
1318-
*/
1319-
_get_logits_warper(generation_config) {
1320-
1321-
// instantiate warpers list
1322-
const warpers = new LogitsProcessorList();
1323-
1324-
if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
1325-
warpers.push(new TemperatureLogitsWarper(generation_config.temperature));
1326-
}
1327-
if (generation_config.top_k !== null && generation_config.top_k !== 0) {
1328-
// TODO: add min_tokens_to_keep
1329-
warpers.push(new TopKLogitsWarper(generation_config.top_k));
1330-
}
1331-
if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
1332-
// TODO: add min_tokens_to_keep
1333-
warpers.push(new TopPLogitsWarper(generation_config.top_p));
1334-
}
1335-
1336-
return warpers;
1337-
}
1338-
13391311
/**
13401312
* @param {GenerationConfig} generation_config
13411313
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
@@ -1455,6 +1427,19 @@ export class PreTrainedModel extends Callable {
14551427
processors.push(new ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale));
14561428
}
14571429

1430+
if (generation_config.do_sample) {
1431+
if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
1432+
processors.push(new TemperatureLogitsWarper(generation_config.temperature));
1433+
}
1434+
// TODO: Add TopPLogitsWarper and TopKLogitsWarper
1435+
// if (generation_config.top_k !== null && generation_config.top_k !== 0) {
1436+
// processors.push(new TopKLogitsWarper(generation_config.top_k));
1437+
// }
1438+
// if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
1439+
// processors.push(new TopPLogitsWarper(generation_config.top_p));
1440+
// }
1441+
}
1442+
14581443
if (logits_processor !== null) {
14591444
processors.extend(logits_processor)
14601445
}

0 commit comments

Comments
 (0)