Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 77161f0

Browse files
committed
tweak softmax
1 parent 194f096 commit 77161f0

File tree

1 file changed

+7
-6
lines changed
  • gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml

1 file changed

+7
-6
lines changed

gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,19 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
100100

101101
val nextToken: Int = when (strategy.strategy) {
102102
GPT2StrategyEnum.TOPK -> {
103-
val filteredLogits = outputLogits
103+
val filteredLogitsWithIndexes = outputLogits
104104
.mapIndexed { index, fl -> (index to fl) }
105105
.sortedByDescending { it.second }
106106
.take(strategy.value)
107107

108108
// Softmax computation on filtered logits
109-
val maxLogitValue = outputLogits.max()!!
110-
val logitsExp = filteredLogits.map { exp(it.second - maxLogitValue) }
111-
val sumExp = logitsExp.sum()
112-
val probs = logitsExp.map { it.div(sumExp) }
109+
val filteredLogits = filteredLogitsWithIndexes.map { it.second }
110+
val maxLogitValue = filteredLogits.max()!!
111+
val logitsExp = filteredLogits.map { exp(it - maxLogitValue) }
112+
val sumExp = logitsExp.sum()
113+
val probs = logitsExp.map { it.div(sumExp) }
113114

114-
val logitsIndexes = filteredLogits.map { it.first }
115+
val logitsIndexes = filteredLogitsWithIndexes.map { it.first }
115116
sample(logitsIndexes, probs)
116117
}
117118
else -> outputLogits.argmax()

0 commit comments

Comments
 (0)