@@ -49,7 +49,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
4949 " Hugging Face is a company that releases awesome projects in machine learning because"
5050 )
5151
52- private val _prompt = MutableLiveData (prompts[ 3 ] )
52+ private val _prompt = MutableLiveData (prompts.random() )
5353 val prompt: LiveData <String > = _prompt
5454
5555 private val _completion = MutableLiveData (" " )
@@ -78,7 +78,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
7878 initJob.join()
7979 autocompleteJob?.cancelAndJoin()
8080 _completion .value = " "
81- generate(" My name is " )
81+ generate(_prompt .value !! )
8282 }
8383 }
8484
@@ -102,11 +102,10 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
102102
103103 val nextToken: Int = when (strategy.strategy) {
104104 GPT2StrategyEnum .TOPK -> {
105- val finalTopK = min(strategy.value, outputLogits.size)
106105 val filteredLogits = outputLogits
107106 .mapIndexed { index, fl -> (index to fl) }
108- .sortedBy { it.second }
109- .takeWhile { it.second < finalTopK }
107+ .sortedByDescending { it.second }
108+ .take(strategy.value)
110109
111110 // Softmax computation on filtered logits
112111 val maxLogitValue = outputLogits.max()!!
@@ -204,10 +203,15 @@ private fun FloatArray.argmax(): Int {
204203}
205204
206205@BindingAdapter(" prompt" , " completion" )
207- fun TextView.formatCompletion (prompt : String , completion : String ) {
208- val str = SpannableStringBuilder (prompt + completion)
209- val bgCompletionColor = ResourcesCompat .getColor(resources, R .color.colorPrimary, context.theme)
210- str.setSpan(android.text.style.BackgroundColorSpan (bgCompletionColor), prompt.length, str.length, Spannable .SPAN_EXCLUSIVE_EXCLUSIVE )
211-
212- text = str
206+ fun TextView.formatCompletion (prompt : String , completion : String ): Unit {
207+ text = when {
208+ completion.isEmpty() -> prompt
209+ else -> {
210+ val str = SpannableStringBuilder (prompt + completion)
211+ val bgCompletionColor = ResourcesCompat .getColor(resources, R .color.colorPrimary, context.theme)
212+ str.setSpan(android.text.style.BackgroundColorSpan (bgCompletionColor), prompt.length, str.length, Spannable .SPAN_EXCLUSIVE_EXCLUSIVE )
213+
214+ str
215+ }
216+ }
213217}
0 commit comments