11package co.huggingface.android_transformers.gpt2.ml
22
33import android.app.Application
4+ import android.text.Spannable
5+ import android.text.SpannableStringBuilder
46import android.util.JsonReader
5- import androidx.lifecycle.AndroidViewModel
6- import androidx.lifecycle.liveData
7- import androidx.lifecycle.viewModelScope
7+ import android.util.Log
8+ import android.widget.TextView
9+ import androidx.core.content.res.ResourcesCompat
10+ import androidx.databinding.BindingAdapter
11+ import androidx.lifecycle.*
12+ import co.huggingface.android_transformers.gpt2.R
813import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
9- import kotlinx.coroutines.Dispatchers
14+ import kotlinx.coroutines.*
1015import org.tensorflow.lite.Interpreter
1116import java.io.BufferedReader
1217import java.io.FileInputStream
@@ -23,35 +28,66 @@ private const val NUM_LITE_THREADS = 4
2328private const val MODEL_PATH = " model.tflite"
2429private const val VOCAB_PATH = " gpt2-vocab.json"
2530private const val MERGES_PATH = " gpt2-merges.txt"
31+ private const val TAG = " GPT2Client"
2632
2733private typealias Predictions = Array <Array <FloatArray >>
2834
2935enum class GPT2StrategyEnum { GREEDY , TOPK }
3036data class GPT2Strategy (val strategy : GPT2StrategyEnum , val value : Int = 0 )
3137
3238class GPT2Client (application : Application ) : AndroidViewModel(application) {
39+ private val initJob: Job
40+ private var autocompleteJob: Job ? = null
3341 private lateinit var tokenizer: GPT2Tokenizer
3442 private lateinit var tflite: Interpreter
3543
44+ private val prompts = arrayOf(
45+ " Before boarding your rocket to Mars, remember to pack these items" ,
46+ " In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English." ,
47+ " Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry." ,
48+ " Today, scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth" ,
49+ " Hugging Face is a company that releases awesome projects in machine learning because"
50+ )
51+
52+ private val _prompt = MutableLiveData (prompts[3 ])
53+ val prompt: LiveData <String > = _prompt
54+
55+ private val _completion = MutableLiveData (" " )
56+ val completion: LiveData <String > = _completion
57+
3658 var strategy = GPT2Strategy (GPT2StrategyEnum .TOPK , 40 )
3759
38- fun init () {
39- if ( ! ::tokenizer.isInitialized) {
60+ init {
61+ initJob = viewModelScope.launch {
4062 val encoder = loadEncoder()
4163 val decoder = encoder.entries.associateBy({ it.value }, { it.key })
4264 val bpeRanks = loadBpeRanks()
4365
4466 tokenizer = GPT2Tokenizer (encoder, decoder, bpeRanks)
67+ tflite = loadModel()
4568 }
69+ }
70+
71+ override fun onCleared () {
72+ super .onCleared()
73+ tflite.close()
74+ }
4675
47- if (! ::tflite.isInitialized) {
48- tflite = loadModel()
76+ fun launchAutocomplete () {
77+ autocompleteJob = viewModelScope.launch {
78+ initJob.join()
79+ autocompleteJob?.cancelAndJoin()
80+ _completion .value = " "
81+ generate(" My name is" )
4982 }
5083 }
5184
52- fun generate (text : String , nbTokens : Int = 10) = liveData<String >(
53- viewModelScope.coroutineContext+ Dispatchers .Default ) {
85+ fun refreshPrompt () {
86+ _prompt .value = prompts.random()
87+ launchAutocomplete()
88+ }
5489
90+ private suspend fun generate (text : String , nbTokens : Int = 50) = withContext(Dispatchers .Default ) {
5591 val tokens = tokenizer.encode(text)
5692 repeat (nbTokens) {
5793 val maxTokens = tokens.takeLast(SEQUENCE_LENGTH ).toIntArray()
@@ -86,13 +122,15 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
86122
87123 tokens.add(nextToken)
88124 val decodedToken = tokenizer.decode(listOf (nextToken))
89- emit(decodedToken)
125+ _completion .postValue(_completion .value + decodedToken)
126+
127+ yield ()
90128 }
91129 }
92130
93- private fun loadModel (): Interpreter {
131+ private suspend fun loadModel (): Interpreter = withContext( Dispatchers . IO ) {
94132 val assetFileDescriptor = getApplication<Application >().assets.openFd(MODEL_PATH )
95- return assetFileDescriptor.use {
133+ assetFileDescriptor.use {
96134 val fileChannel = FileInputStream (assetFileDescriptor.fileDescriptor).channel
97135 val modelBuffer = fileChannel.map(FileChannel .MapMode .READ_ONLY , it.startOffset, it.declaredLength)
98136
@@ -102,8 +140,8 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
102140 }
103141 }
104142
105- private fun loadEncoder (): Map <String , Int > {
106- return hashMapOf<String , Int >().apply {
143+ private suspend fun loadEncoder (): Map <String , Int > = withContext( Dispatchers . IO ) {
144+ hashMapOf<String , Int >().apply {
107145 val vocabStream = getApplication<Application >().assets.open(VOCAB_PATH )
108146 vocabStream.use {
109147 val vocabReader = JsonReader (InputStreamReader (it, " UTF-8" ))
@@ -118,8 +156,8 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
118156 }
119157 }
120158
121- private fun loadBpeRanks (): Map <Pair <String , String >, Int> {
122- return hashMapOf<Pair <String , String >, Int > ().apply {
159+ private suspend fun loadBpeRanks ():Map <Pair <String , String >, Int> = withContext( Dispatchers . IO ) {
160+ hashMapOf<Pair <String , String >, Int > ().apply {
123161 val mergesStream = getApplication<Application >().assets.open(MERGES_PATH )
124162 mergesStream.use { stream ->
125163 val mergesReader = BufferedReader (InputStreamReader (stream))
@@ -136,7 +174,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
136174}
137175
138176private fun randomIndex (probs : List <Float >): Int {
139- val rnd = Random .nextFloat()
177+ val rnd = probs.sum() * Random .nextFloat()
140178 var acc = 0f
141179
142180 probs.forEachIndexed { i, fl ->
@@ -164,3 +202,12 @@ private fun FloatArray.argmax(): Int {
164202
165203 return bestIndex
166204}
205+
206+ @BindingAdapter(" prompt" , " completion" )
207+ fun TextView.formatCompletion (prompt : String , completion : String ) {
208+ val str = SpannableStringBuilder (prompt + completion)
209+ val bgCompletionColor = ResourcesCompat .getColor(resources, R .color.colorPrimary, context.theme)
210+ str.setSpan(android.text.style.BackgroundColorSpan (bgCompletionColor), prompt.length, str.length, Spannable .SPAN_EXCLUSIVE_EXCLUSIVE )
211+
212+ text = str
213+ }
0 commit comments