import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
+import androidx.activity.addCallback
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
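+ // Keep a handle to the streaming generation coroutine so it can be cancelled in onStop()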
+ private var generationJob: Job? = null
// Conversation states
private var isModelReady = false
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
+ // ViewModel boilerplate and state management are beyond the scope of this basic sample
+ onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignoring back press for simplicity") }
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
- messagesRv.layoutManager = LinearLayoutManager(this)
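+ // stackFromEnd anchors items to the bottom so the newest message stays visible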
+ messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
* Validates the user message and sends it to [InferenceEngine]
*/
private fun handleUserInput() {
- userInputEt.text.toString().also { userSsg ->
- if (userSsg.isEmpty()) {
+ userInputEt.text.toString().also { userMsg ->
+ if (userMsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
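+ // Disable input while a response is streaming; re-enabled in onCompletion below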
+ userInputEt.isEnabled = false
userActionFab.isEnabled = false
// Update message states
- messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
+ messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
- lifecycleScope.launch(Dispatchers.Default) {
- engine.sendUserPrompt(userSsg)
+ generationJob = lifecycleScope.launch(Dispatchers.Default) {
+ engine.sendUserPrompt(userMsg)
.onCompletion {
withContext(Dispatchers.Main) {
+ userInputEt.isEnabled = true
userActionFab.isEnabled = true
}
}.collect { token ->
- val messageCount = messages.size
- check(messageCount > 0 && !messages[messageCount - 1].isUser)
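+ // Mutate the message list and notify the adapter on the main thread only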
+ withContext(Dispatchers.Main) {
+ val messageCount = messages.size
+ check(messageCount > 0 && !messages[messageCount - 1].isUser)
- messages.removeAt(messageCount - 1).copy(
- content = lastAssistantMsg.append(token).toString()
- ).let { messages.add(it) }
+ messages.removeAt(messageCount - 1).copy(
+ content = lastAssistantMsg.append(token).toString()
+ ).let { messages.add(it) }
- withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
/**
* Run a benchmark with the model file
*/
+ @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
if (!it.exists()) { it.mkdir() }
}
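+ // Stop streaming tokens once the Activity is no longer visible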
+ override fun onStop() {
+ generationJob?.cancel()
+ super.onStop()
+ }
+
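+ // Free the engine's native resources when the Activity is destroyed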
+ override fun onDestroy() {
+ engine.destroy()
+ super.onDestroy()
+ }
+
companion object {
private val TAG = MainActivity::class.java.simpleName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
+import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException
private val _state =
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
- override val state: StateFlow<InferenceEngine.State> = _state
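+ // asStateFlow() exposes a read-only view that callers cannot cast back to MutableStateFlow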
+ override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()
private var _readyForSystemPrompt = false
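+ // @Volatile so a write from the caller's thread is immediately visible to the generation loop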
+ @Volatile
+ private var _cancelGeneration = false
/**
* Single-threaded coroutine dispatcher & scope for llama asynchronous operations
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
+
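+ // Reset the cancellation flag for the freshly loaded model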
+ _cancelGeneration = false
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
_state.value = InferenceEngine.State.Generating
- while (true) {
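+ // Check the flag on every iteration so cleanUp()/destroy() can abort generation promptly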
+ while (!_cancelGeneration) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
- Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ if (_cancelGeneration) {
+ Log.i(TAG, "Assistant generation aborted per requested.")
+ } else {
+ Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ }
_state.value = InferenceEngine.State.ModelReady
} catch (e: CancellationException) {
- Log.i(TAG, "Generation cancelled by user.")
+ Log.i(TAG, "Assistant generation's flow collection cancelled.")
_state.value = InferenceEngine.State.ModelReady
throw e
} catch (e: Exception) {
/**
* Unloads the model and frees resources, or resets error states
*/
- override suspend fun cleanUp() =
- withContext(llamaDispatcher) {
+ override fun cleanUp() {
+ _cancelGeneration = true
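+ // runBlocking on the single-threaded dispatcher waits for any in-flight llama call to finish first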
+ runBlocking(llamaDispatcher) {
when (val state = _state.value) {
is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...")
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}
+ }
/**
* Cancels all ongoing coroutines and frees the GGML backends
*/
override fun destroy() {
- _readyForSystemPrompt = false
- llamaScope.cancel()
- when(_state.value) {
- is InferenceEngine.State.Uninitialized -> {}
- is InferenceEngine.State.Initialized -> shutdown()
- else -> { unload(); shutdown() }
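+ // Raise the flag before blocking, so an in-flight generation loop can exit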
+ _cancelGeneration = true
+ runBlocking(llamaDispatcher) {
+ _readyForSystemPrompt = false
+ when(_state.value) {
+ is InferenceEngine.State.Uninitialized -> {}
+ is InferenceEngine.State.Initialized -> shutdown()
+ else -> { unload(); shutdown() }
+ }
}
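+ // Cancel the scope only after the model has been unloaded on its own dispatcher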
+ llamaScope.cancel()
}
}