This repository was archived by the owner on Jan 10, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -100,18 +100,19 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
100100
101101 val nextToken: Int = when (strategy.strategy) {
102102 GPT2StrategyEnum .TOPK -> {
103- val filteredLogits = outputLogits
103+ val filteredLogitsWithIndexes = outputLogits
104104 .mapIndexed { index, fl -> (index to fl) }
105105 .sortedByDescending { it.second }
106106 .take(strategy.value)
107107
108108 // Softmax computation on filtered logits
109- val maxLogitValue = outputLogits.max()!!
110- val logitsExp = filteredLogits.map { exp(it.second - maxLogitValue) }
111- val sumExp = logitsExp.sum()
112- val probs = logitsExp.map { it.div(sumExp) }
109+ val filteredLogits = filteredLogitsWithIndexes.map { it.second }
110+ val maxLogitValue = filteredLogits.max()!!
111+ val logitsExp = filteredLogits.map { exp(it - maxLogitValue) }
112+ val sumExp = logitsExp.sum()
113+ val probs = logitsExp.map { it.div(sumExp) }
113114
114- val logitsIndexes = filteredLogits .map { it.first }
115+ val logitsIndexes = filteredLogitsWithIndexes .map { it.first }
115116 sample(logitsIndexes, probs)
116117 }
117118 else -> outputLogits.argmax()
You can’t perform that action at this time.
0 commit comments