Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit d4116d7

Browse files
committed
last fixes gpt2
1 parent 5c2ab83 commit d4116d7

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
4949
"Hugging Face is a company that releases awesome projects in machine learning because"
5050
)
5151

52-
private val _prompt = MutableLiveData(prompts[3])
52+
private val _prompt = MutableLiveData(prompts.random())
5353
val prompt: LiveData<String> = _prompt
5454

5555
private val _completion = MutableLiveData("")
@@ -78,7 +78,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
7878
initJob.join()
7979
autocompleteJob?.cancelAndJoin()
8080
_completion.value = ""
81-
generate("My name is")
81+
generate(_prompt.value!!)
8282
}
8383
}
8484

@@ -102,11 +102,10 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
102102

103103
val nextToken: Int = when (strategy.strategy) {
104104
GPT2StrategyEnum.TOPK -> {
105-
val finalTopK = min(strategy.value, outputLogits.size)
106105
val filteredLogits = outputLogits
107106
.mapIndexed { index, fl -> (index to fl) }
108-
.sortedBy { it.second }
109-
.takeWhile { it.second < finalTopK }
107+
.sortedByDescending { it.second }
108+
.take(strategy.value)
110109

111110
// Softmax computation on filtered logits
112111
val maxLogitValue = outputLogits.max()!!
@@ -204,10 +203,15 @@ private fun FloatArray.argmax(): Int {
204203
}
205204

206205
@BindingAdapter("prompt", "completion")
207-
fun TextView.formatCompletion(prompt: String, completion: String) {
208-
val str = SpannableStringBuilder(prompt + completion)
209-
val bgCompletionColor = ResourcesCompat.getColor(resources, R.color.colorPrimary, context.theme)
210-
str.setSpan(android.text.style.BackgroundColorSpan(bgCompletionColor), prompt.length, str.length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE)
211-
212-
text = str
206+
fun TextView.formatCompletion(prompt: String, completion: String): Unit {
207+
text = when {
208+
completion.isEmpty() -> prompt
209+
else -> {
210+
val str = SpannableStringBuilder(prompt + completion)
211+
val bgCompletionColor = ResourcesCompat.getColor(resources, R.color.colorPrimary, context.theme)
212+
str.setSpan(android.text.style.BackgroundColorSpan(bgCompletionColor), prompt.length, str.length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE)
213+
214+
str
215+
}
216+
}
213217
}

gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ class GPT2Tokenizer(
3535
var pairs = getPairs(word)
3636

3737
while (true) {
38+
if (!pairs.any { bpeRanks.containsKey(it) }) break
3839
val (first, second) = pairs.minBy { bpeRanks.getOrDefault(it, Int.MAX_VALUE) } ?: break
3940

4041
var i = 0
4142
val newWord = mutableListOf<String>()
4243
while (i < word.size) {
43-
val j = word.subList(i, word.size).indexOf(first)
44+
val j = word.withIndex().indexOfFirst { it.index >= i && it.value == first }
4445
if (j != -1) {
4546
newWord.addAll(word.subList(i, j))
4647
i = j

0 commit comments

Comments
 (0)