Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 5c2ab83

Browse files
committed
interface ok
1 parent d20f90c commit 5c2ab83

File tree

8 files changed

+168
-61
lines changed

8 files changed

+168
-61
lines changed

.idea/misc.xml

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gpt2/build.gradle

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'com.android.application'
22
apply plugin: 'kotlin-android'
33
apply plugin: 'kotlin-android-extensions'
4+
apply plugin: 'kotlin-kapt'
45

56
android {
67
compileSdkVersion 29
@@ -21,6 +22,10 @@ android {
2122
noCompress "tflite"
2223
}
2324

25+
dataBinding {
26+
enabled = true
27+
}
28+
2429
kotlinOptions {
2530
jvmTarget = JavaVersion.VERSION_1_8
2631
}
@@ -43,6 +48,7 @@ dependencies {
4348
implementation 'androidx.appcompat:appcompat:1.1.0'
4449
implementation 'androidx.core:core-ktx:1.1.0'
4550
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
51+
implementation 'com.google.android.material:material:1.1.0-beta02'
4652
testImplementation 'junit:junit:4.12'
4753
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
4854
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'

gpt2/src/main/java/co/huggingface/android_transformers/gpt2/MainActivity.kt

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,23 @@ package co.huggingface.android_transformers.gpt2
22

33
import androidx.appcompat.app.AppCompatActivity
44
import android.os.Bundle
5-
import android.os.Handler
6-
import android.os.HandlerThread
75
import androidx.activity.viewModels
8-
import androidx.lifecycle.observe
6+
import androidx.databinding.DataBindingUtil
7+
import co.huggingface.android_transformers.gpt2.databinding.ActivityMainBinding
98

109
class MainActivity : AppCompatActivity() {
1110
private val gpt2: co.huggingface.android_transformers.gpt2.ml.GPT2Client by viewModels()
12-
private val handlerThread by lazy { HandlerThread("GPT2Client") }
13-
private val handler by lazy {
14-
handlerThread.start()
15-
Handler(handlerThread.looper)
16-
}
1711

1812
override fun onCreate(savedInstanceState: Bundle?) {
1913
super.onCreate(savedInstanceState)
20-
setContentView(R.layout.activity_main)
2114

22-
handler.post {
23-
gpt2.init()
24-
val generation = gpt2.generate("My name is")
15+
val binding: ActivityMainBinding
16+
= DataBindingUtil.setContentView(this, R.layout.activity_main)
2517

26-
runOnUiThread {
27-
generation.observe(this) {
28-
print(it)
29-
}
30-
}
31-
}
32-
}
18+
// Bind layout with ViewModel
19+
binding.vm = gpt2
3320

34-
override fun onDestroy() {
35-
super.onDestroy()
36-
handlerThread.quit()
21+
// LiveData needs the lifecycle owner
22+
binding.lifecycleOwner = this
3723
}
3824
}

gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package co.huggingface.android_transformers.gpt2.ml
22

33
import android.app.Application
4+
import android.text.Spannable
5+
import android.text.SpannableStringBuilder
46
import android.util.JsonReader
5-
import androidx.lifecycle.AndroidViewModel
6-
import androidx.lifecycle.liveData
7-
import androidx.lifecycle.viewModelScope
7+
import android.util.Log
8+
import android.widget.TextView
9+
import androidx.core.content.res.ResourcesCompat
10+
import androidx.databinding.BindingAdapter
11+
import androidx.lifecycle.*
12+
import co.huggingface.android_transformers.gpt2.R
813
import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
9-
import kotlinx.coroutines.Dispatchers
14+
import kotlinx.coroutines.*
1015
import org.tensorflow.lite.Interpreter
1116
import java.io.BufferedReader
1217
import java.io.FileInputStream
@@ -23,35 +28,66 @@ private const val NUM_LITE_THREADS = 4
2328
private const val MODEL_PATH = "model.tflite"
2429
private const val VOCAB_PATH = "gpt2-vocab.json"
2530
private const val MERGES_PATH = "gpt2-merges.txt"
31+
private const val TAG = "GPT2Client"
2632

2733
private typealias Predictions = Array<Array<FloatArray>>
2834

2935
enum class GPT2StrategyEnum { GREEDY, TOPK }
3036
data class GPT2Strategy(val strategy: GPT2StrategyEnum, val value: Int = 0)
3137

3238
class GPT2Client(application: Application) : AndroidViewModel(application) {
39+
private val initJob: Job
40+
private var autocompleteJob: Job? = null
3341
private lateinit var tokenizer: GPT2Tokenizer
3442
private lateinit var tflite: Interpreter
3543

44+
private val prompts = arrayOf(
45+
"Before boarding your rocket to Mars, remember to pack these items",
46+
"In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
47+
"Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry.",
48+
"Today, scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth",
49+
"Hugging Face is a company that releases awesome projects in machine learning because"
50+
)
51+
52+
private val _prompt = MutableLiveData(prompts[3])
53+
val prompt: LiveData<String> = _prompt
54+
55+
private val _completion = MutableLiveData("")
56+
val completion: LiveData<String> = _completion
57+
3658
var strategy = GPT2Strategy(GPT2StrategyEnum.TOPK, 40)
3759

38-
fun init() {
39-
if (!::tokenizer.isInitialized) {
60+
init {
61+
initJob = viewModelScope.launch {
4062
val encoder = loadEncoder()
4163
val decoder = encoder.entries.associateBy({ it.value }, { it.key })
4264
val bpeRanks = loadBpeRanks()
4365

4466
tokenizer = GPT2Tokenizer(encoder, decoder, bpeRanks)
67+
tflite = loadModel()
4568
}
69+
}
70+
71+
override fun onCleared() {
72+
super.onCleared()
73+
tflite.close()
74+
}
4675

47-
if (!::tflite.isInitialized) {
48-
tflite = loadModel()
76+
fun launchAutocomplete() {
77+
autocompleteJob = viewModelScope.launch {
78+
initJob.join()
79+
autocompleteJob?.cancelAndJoin()
80+
_completion.value = ""
81+
generate("My name is")
4982
}
5083
}
5184

52-
fun generate(text: String, nbTokens: Int = 10) = liveData<String>(
53-
viewModelScope.coroutineContext+Dispatchers.Default) {
85+
fun refreshPrompt() {
86+
_prompt.value = prompts.random()
87+
launchAutocomplete()
88+
}
5489

90+
private suspend fun generate(text: String, nbTokens: Int = 50) = withContext(Dispatchers.Default) {
5591
val tokens = tokenizer.encode(text)
5692
repeat (nbTokens) {
5793
val maxTokens = tokens.takeLast(SEQUENCE_LENGTH).toIntArray()
@@ -86,13 +122,15 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
86122

87123
tokens.add(nextToken)
88124
val decodedToken = tokenizer.decode(listOf(nextToken))
89-
emit(decodedToken)
125+
_completion.postValue(_completion.value + decodedToken)
126+
127+
yield()
90128
}
91129
}
92130

93-
private fun loadModel(): Interpreter {
131+
private suspend fun loadModel(): Interpreter = withContext(Dispatchers.IO) {
94132
val assetFileDescriptor = getApplication<Application>().assets.openFd(MODEL_PATH)
95-
return assetFileDescriptor.use {
133+
assetFileDescriptor.use {
96134
val fileChannel = FileInputStream(assetFileDescriptor.fileDescriptor).channel
97135
val modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, it.startOffset, it.declaredLength)
98136

@@ -102,8 +140,8 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
102140
}
103141
}
104142

105-
private fun loadEncoder(): Map<String, Int> {
106-
return hashMapOf<String, Int>().apply {
143+
private suspend fun loadEncoder(): Map<String, Int> = withContext(Dispatchers.IO) {
144+
hashMapOf<String, Int>().apply {
107145
val vocabStream = getApplication<Application>().assets.open(VOCAB_PATH)
108146
vocabStream.use {
109147
val vocabReader = JsonReader(InputStreamReader(it, "UTF-8"))
@@ -118,8 +156,8 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
118156
}
119157
}
120158

121-
private fun loadBpeRanks(): Map<Pair<String, String>, Int> {
122-
return hashMapOf<Pair<String, String>, Int>().apply {
159+
private suspend fun loadBpeRanks():Map<Pair<String, String>, Int> = withContext(Dispatchers.IO) {
160+
hashMapOf<Pair<String, String>, Int>().apply {
123161
val mergesStream = getApplication<Application>().assets.open(MERGES_PATH)
124162
mergesStream.use { stream ->
125163
val mergesReader = BufferedReader(InputStreamReader(stream))
@@ -136,7 +174,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
136174
}
137175

138176
private fun randomIndex(probs: List<Float>): Int {
139-
val rnd = Random.nextFloat()
177+
val rnd = probs.sum() * Random.nextFloat()
140178
var acc = 0f
141179

142180
probs.forEachIndexed { i, fl ->
@@ -164,3 +202,12 @@ private fun FloatArray.argmax(): Int {
164202

165203
return bestIndex
166204
}
205+
206+
@BindingAdapter("prompt", "completion")
207+
fun TextView.formatCompletion(prompt: String, completion: String) {
208+
val str = SpannableStringBuilder(prompt + completion)
209+
val bgCompletionColor = ResourcesCompat.getColor(resources, R.color.colorPrimary, context.theme)
210+
str.setSpan(android.text.style.BackgroundColorSpan(bgCompletionColor), prompt.length, str.length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE)
211+
212+
text = str
213+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
<vector xmlns:android="http://schemas.android.com/apk/res/android"
2+
android:width="24dp"
3+
android:height="24dp"
4+
android:viewportWidth="24.0"
5+
android:viewportHeight="24.0">
6+
<path
7+
android:fillColor="#FF000000"
8+
android:pathData="M10.59,9.17L5.41,4 4,5.41l5.17,5.17 1.42,-1.41zM14.5,4l2.04,2.04L4,18.59 5.41,20 17.96,7.46 20,9.5L20,4h-5.5zM14.83,13.41l-1.41,1.41 3.13,3.13L14.5,20L20,20v-5.5l-2.04,2.04 -3.13,-3.13z"/>
9+
</vector>
Lines changed: 59 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,61 @@
11
<?xml version="1.0" encoding="utf-8"?>
2-
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
2+
3+
<layout
4+
xmlns:android="http://schemas.android.com/apk/res/android"
35
xmlns:app="http://schemas.android.com/apk/res-auto"
4-
xmlns:tools="http://schemas.android.com/tools"
5-
android:layout_width="match_parent"
6-
android:layout_height="match_parent"
7-
tools:context=".MainActivity">
8-
9-
<TextView
10-
android:layout_width="wrap_content"
11-
android:layout_height="wrap_content"
12-
android:text="Hello World!"
13-
app:layout_constraintBottom_toBottomOf="parent"
14-
app:layout_constraintLeft_toLeftOf="parent"
15-
app:layout_constraintRight_toRightOf="parent"
16-
app:layout_constraintTop_toTopOf="parent" />
17-
18-
</androidx.constraintlayout.widget.ConstraintLayout>
6+
xmlns:tools="http://schemas.android.com/tools">
7+
8+
<data>
9+
<variable
10+
name="vm"
11+
type="co.huggingface.android_transformers.gpt2.ml.GPT2Client"/>
12+
</data>
13+
14+
<androidx.constraintlayout.widget.ConstraintLayout
15+
android:layout_width="match_parent"
16+
android:layout_height="match_parent"
17+
tools:context=".MainActivity">
18+
19+
<com.google.android.material.button.MaterialButton
20+
android:id="@+id/shuffle_button"
21+
style="@style/Widget.MaterialComponents.Button.TextButton.Icon"
22+
android:layout_width="wrap_content"
23+
android:layout_height="wrap_content"
24+
android:text="Shuffle prompt text"
25+
app:layout_constraintTop_toTopOf="parent"
26+
app:layout_constraintLeft_toLeftOf="parent"
27+
app:layout_constraintRight_toRightOf="parent"
28+
android:layout_marginTop="30dp"
29+
app:icon="@drawable/ic_shuffle_black_24dp"
30+
android:onClick="@{() -> vm.refreshPrompt()}"/>
31+
32+
<com.google.android.material.button.MaterialButton
33+
android:id="@+id/autocomplete_button"
34+
android:layout_width="wrap_content"
35+
android:layout_height="wrap_content"
36+
android:textSize="16sp"
37+
android:text="Trigger autocomplete"
38+
app:layout_constraintTop_toBottomOf="@id/shuffle_button"
39+
app:layout_constraintLeft_toLeftOf="parent"
40+
app:layout_constraintRight_toRightOf="parent"
41+
android:onClick="@{() -> vm.launchAutocomplete()}"/>
42+
43+
<TextView
44+
android:id="@+id/prompt"
45+
android:layout_width="wrap_content"
46+
android:layout_height="wrap_content"
47+
android:layout_marginTop="50dp"
48+
android:padding="30dp"
49+
android:textColor="@color/colorOnPrimary"
50+
android:textSize="16sp"
51+
app:completion="@{vm.completion}"
52+
app:layout_constraintLeft_toLeftOf="parent"
53+
app:layout_constraintRight_toRightOf="parent"
54+
app:layout_constraintTop_toBottomOf="@id/autocomplete_button"
55+
app:lineHeight="22sp"
56+
app:prompt="@{vm.prompt}"
57+
tools:text="@tools:sample/lorem/random" />
58+
59+
</androidx.constraintlayout.widget.ConstraintLayout>
60+
61+
</layout>
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
<?xml version="1.0" encoding="utf-8"?>
22
<resources>
3-
<color name="colorPrimary">#008577</color>
4-
<color name="colorPrimaryDark">#00574B</color>
5-
<color name="colorAccent">#D81B60</color>
3+
<color name="colorPrimary">#fbc02d</color>
4+
<color name="colorPrimaryVariant">#fff263</color>
5+
<color name="colorPrimaryDark">#c49000</color>
6+
<color name="colorSecondary">#3f51b5</color>
7+
<color name="colorSecondaryVariant">#757de8</color>
8+
<color name="colorSecondaryDark">#002984</color>
9+
<color name="colorOnPrimary">#000000</color>
10+
<color name="colorOnSecondary">#ffffff</color>
611
</resources>
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
<resources>
22

33
<!-- Base application theme. -->
4-
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
4+
<style name="AppTheme" parent="Theme.MaterialComponents.DayNight.NoActionBar">
55
<!-- Customize your theme here. -->
66
<item name="colorPrimary">@color/colorPrimary</item>
77
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
8-
<item name="colorAccent">@color/colorAccent</item>
8+
<item name="colorPrimaryVariant">@color/colorPrimaryVariant</item>
9+
10+
<item name="colorSecondary">@color/colorSecondary</item>
11+
<item name="colorSecondaryVariant">@color/colorSecondaryVariant</item>
12+
<item name="colorOnPrimary">@color/colorOnPrimary</item>
13+
<item name="colorOnSecondary">@color/colorOnSecondary</item>
14+
<!-- <item name="colorAccent">@color/colorAccent</item>-->
915
</style>
1016

1117
</resources>

0 commit comments

Comments
 (0)