1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 package com.example.executorchllamademo; 10 11 import android.Manifest; 12 import android.app.ActivityManager; 13 import android.app.AlertDialog; 14 import android.content.ContentResolver; 15 import android.content.ContentValues; 16 import android.content.Intent; 17 import android.content.pm.PackageManager; 18 import android.net.Uri; 19 import android.os.Build; 20 import android.os.Bundle; 21 import android.os.Handler; 22 import android.os.Looper; 23 import android.os.Process; 24 import android.provider.MediaStore; 25 import android.system.ErrnoException; 26 import android.system.Os; 27 import android.util.Log; 28 import android.view.View; 29 import android.view.inputmethod.InputMethodManager; 30 import android.widget.EditText; 31 import android.widget.ImageButton; 32 import android.widget.ImageView; 33 import android.widget.LinearLayout; 34 import android.widget.ListView; 35 import android.widget.TextView; 36 import android.widget.Toast; 37 import androidx.activity.result.ActivityResultLauncher; 38 import androidx.activity.result.PickVisualMediaRequest; 39 import androidx.activity.result.contract.ActivityResultContracts; 40 import androidx.annotation.NonNull; 41 import androidx.appcompat.app.AppCompatActivity; 42 import androidx.constraintlayout.widget.ConstraintLayout; 43 import androidx.core.app.ActivityCompat; 44 import androidx.core.content.ContextCompat; 45 import com.google.gson.Gson; 46 import com.google.gson.reflect.TypeToken; 47 import java.lang.reflect.Type; 48 import java.util.ArrayList; 49 import java.util.List; 50 import java.util.concurrent.Executor; 51 import java.util.concurrent.Executors; 52 import org.pytorch.executorch.LlamaCallback; 53 import org.pytorch.executorch.LlamaModule; 54 55 public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback { 56 private EditText mEditTextMessage; 57 private ImageButton mSendButton; 58 private ImageButton mGalleryButton; 59 private ImageButton mCameraButton; 60 private ListView mMessagesView; 61 private MessageAdapter mMessageAdapter; 62 private LlamaModule mModule = null; 63 private Message mResultMessage = null; 64 private ImageButton mSettingsButton; 65 private TextView mMemoryView; 66 private ActivityResultLauncher<PickVisualMediaRequest> mPickGallery; 67 private ActivityResultLauncher<Uri> mCameraRoll; 68 private List<Uri> mSelectedImageUri; 69 private ConstraintLayout mMediaPreviewConstraintLayout; 70 private LinearLayout mAddMediaLayout; 71 private static final int MAX_NUM_OF_IMAGES = 5; 72 private static final int REQUEST_IMAGE_CAPTURE = 1; 73 private Uri cameraImageUri; 74 private DemoSharedPreferences mDemoSharedPreferences; 75 private SettingsFields mCurrentSettingsFields; 76 private Handler mMemoryUpdateHandler; 77 private Runnable memoryUpdater; 78 private int promptID = 0; 79 private long startPos = 0; 80 private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; 81 private Executor executor; 82 83 @Override onResult(String result)84 public void onResult(String result) { 85 if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { 86 return; 87 } 88 if (result.equals("\n\n") || result.equals("\n")) { 89 if (!mResultMessage.getText().isEmpty()) { 90 mResultMessage.appendText(result); 91 run(); 92 } 93 } else { 94 mResultMessage.appendText(result); 95 run(); 96 } 97 } 98 99 @Override onStats(float tps)100 public void onStats(float tps) { 101 runOnUiThread( 102 () -> { 103 if (mResultMessage != null) { 104 mResultMessage.setTokensPerSecond(tps); 105 mMessageAdapter.notifyDataSetChanged(); 106 } 107 }); 108 } 109 setLocalModel(String modelPath, String tokenizerPath, float temperature)110 private void setLocalModel(String modelPath, String tokenizerPath, float temperature) { 111 Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0); 112 ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath); 113 runOnUiThread( 114 () -> { 115 mSendButton.setEnabled(false); 116 mMessageAdapter.add(modelLoadingMessage); 117 mMessageAdapter.notifyDataSetChanged(); 118 }); 119 if (mModule != null) { 120 ETLogging.getInstance().log("Start deallocating existing module instance"); 121 mModule.resetNative(); 122 mModule = null; 123 ETLogging.getInstance().log("Completed deallocating existing module instance"); 124 } 125 long runStartTime = System.currentTimeMillis(); 126 mModule = 127 new LlamaModule( 128 ModelUtils.getModelCategory( 129 mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()), 130 modelPath, 131 tokenizerPath, 132 temperature); 133 int loadResult = mModule.load(); 134 long loadDuration = System.currentTimeMillis() - runStartTime; 135 String modelLoadError = ""; 136 String modelInfo = ""; 137 if (loadResult != 0) { 138 // TODO: Map the error code to a reason to let the user know why model loading failed 139 modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n"; 140 loadDuration = 0; 141 AlertDialog.Builder builder = new AlertDialog.Builder(this); 142 builder.setTitle("Load failed: " + loadResult); 143 runOnUiThread( 144 () -> { 145 AlertDialog alert = builder.create(); 146 alert.show(); 147 }); 148 } else { 149 String[] segments = modelPath.split("/"); 150 String pteName = segments[segments.length - 1]; 151 segments = tokenizerPath.split("/"); 152 String tokenizerName = segments[segments.length - 1]; 153 modelInfo = 154 "Successfully loaded model. " 155 + pteName 156 + " and tokenizer " 157 + tokenizerName 158 + " in " 159 + (float) loadDuration / 1000 160 + " sec." 161 + " You can send text or image for inference"; 162 163 if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { 164 ETLogging.getInstance().log("Llava start prefill prompt"); 165 startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0); 166 ETLogging.getInstance().log("Llava completes prefill prompt"); 167 } 168 } 169 170 Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); 171 172 String modelLoggingInfo = 173 modelLoadError 174 + "Model path: " 175 + modelPath 176 + "\nTokenizer path: " 177 + tokenizerPath 178 + "\nBackend: " 179 + mCurrentSettingsFields.getBackendType().toString() 180 + "\nModelType: " 181 + ModelUtils.getModelCategory( 182 mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) 183 + "\nTemperature: " 184 + temperature 185 + "\nModel loaded time: " 186 + loadDuration 187 + " ms"; 188 ETLogging.getInstance().log("Load complete. " + modelLoggingInfo); 189 190 runOnUiThread( 191 () -> { 192 mSendButton.setEnabled(true); 193 mMessageAdapter.remove(modelLoadingMessage); 194 mMessageAdapter.add(modelLoadedMessage); 195 mMessageAdapter.notifyDataSetChanged(); 196 }); 197 } 198 loadLocalModelAndParameters( String modelFilePath, String tokenizerFilePath, float temperature)199 private void loadLocalModelAndParameters( 200 String modelFilePath, String tokenizerFilePath, float temperature) { 201 Runnable runnable = 202 new Runnable() { 203 @Override 204 public void run() { 205 setLocalModel(modelFilePath, tokenizerFilePath, temperature); 206 } 207 }; 208 new Thread(runnable).start(); 209 } 210 populateExistingMessages(String existingMsgJSON)211 private void populateExistingMessages(String existingMsgJSON) { 212 Gson gson = new Gson(); 213 Type type = new TypeToken<ArrayList<Message>>() {}.getType(); 214 ArrayList<Message> savedMessages = gson.fromJson(existingMsgJSON, type); 215 for (Message msg : savedMessages) { 216 mMessageAdapter.add(msg); 217 } 218 mMessageAdapter.notifyDataSetChanged(); 219 } 220 setPromptID()221 private int setPromptID() { 222 223 return mMessageAdapter.getMaxPromptID() + 1; 224 } 225 226 @Override onCreate(Bundle savedInstanceState)227 protected void onCreate(Bundle savedInstanceState) { 228 super.onCreate(savedInstanceState); 229 setContentView(R.layout.activity_main); 230 231 if (Build.VERSION.SDK_INT >= 21) { 232 getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); 233 getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); 234 } 235 236 try { 237 Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); 238 Os.setenv("LD_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); 239 } catch (ErrnoException e) { 240 finish(); 241 } 242 243 mEditTextMessage = requireViewById(R.id.editTextMessage); 244 mSendButton = requireViewById(R.id.sendButton); 245 mSendButton.setEnabled(false); 246 mMessagesView = requireViewById(R.id.messages_view); 247 mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList<Message>()); 248 mMessagesView.setAdapter(mMessageAdapter); 249 mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); 250 String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); 251 if (!existingMsgJSON.isEmpty()) { 252 populateExistingMessages(existingMsgJSON); 253 promptID = setPromptID(); 254 } 255 mSettingsButton = requireViewById(R.id.settings); 256 mSettingsButton.setOnClickListener( 257 view -> { 258 Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class); 259 MainActivity.this.startActivity(myIntent); 260 }); 261 262 mCurrentSettingsFields = new SettingsFields(); 263 mMemoryUpdateHandler = new Handler(Looper.getMainLooper()); 264 onModelRunStopped(); 265 setupMediaButton(); 266 setupGalleryPicker(); 267 setupCameraRoll(); 268 startMemoryUpdate(); 269 setupShowLogsButton(); 270 executor = Executors.newSingleThreadExecutor(); 271 } 272 273 @Override onPause()274 protected void onPause() { 275 super.onPause(); 276 mDemoSharedPreferences.addMessages(mMessageAdapter); 277 } 278 279 @Override onResume()280 protected void onResume() { 281 super.onResume(); 282 // Check for if settings parameters have changed 283 Gson gson = new Gson(); 284 String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); 285 if (!settingsFieldsJSON.isEmpty()) { 286 SettingsFields updatedSettingsFields = 287 gson.fromJson(settingsFieldsJSON, SettingsFields.class); 288 if (updatedSettingsFields == null) { 289 // Added this check, because gson.fromJson can return null 290 askUserToSelectModel(); 291 return; 292 } 293 boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields); 294 boolean isLoadModel = updatedSettingsFields.getIsLoadModel(); 295 setBackendMode(updatedSettingsFields.getBackendType()); 296 if (isUpdated) { 297 if (isLoadModel) { 298 // If users change the model file, but not pressing loadModelButton, we won't load the new 299 // model 300 checkForUpdateAndReloadModel(updatedSettingsFields); 301 } else { 302 askUserToSelectModel(); 303 } 304 305 checkForClearChatHistory(updatedSettingsFields); 306 // Update current to point to the latest 307 mCurrentSettingsFields = new SettingsFields(updatedSettingsFields); 308 } 309 } else { 310 askUserToSelectModel(); 311 } 312 } 313 setBackendMode(BackendType backendType)314 private void setBackendMode(BackendType backendType) { 315 if (backendType.equals(BackendType.XNNPACK) || backendType.equals(BackendType.QUALCOMM)) { 316 setXNNPACKMode(); 317 } else if (backendType.equals(BackendType.MEDIATEK)) { 318 setMediaTekMode(); 319 } 320 } 321 setXNNPACKMode()322 private void setXNNPACKMode() { 323 requireViewById(R.id.addMediaButton).setVisibility(View.VISIBLE); 324 } 325 setMediaTekMode()326 private void setMediaTekMode() { 327 requireViewById(R.id.addMediaButton).setVisibility(View.GONE); 328 } 329 checkForClearChatHistory(SettingsFields updatedSettingsFields)330 private void checkForClearChatHistory(SettingsFields updatedSettingsFields) { 331 if (updatedSettingsFields.getIsClearChatHistory()) { 332 mMessageAdapter.clear(); 333 mMessageAdapter.notifyDataSetChanged(); 334 mDemoSharedPreferences.removeExistingMessages(); 335 // changing to false since chat history has been cleared. 336 updatedSettingsFields.saveIsClearChatHistory(false); 337 mDemoSharedPreferences.addSettings(updatedSettingsFields); 338 } 339 } 340 checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields)341 private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) { 342 // TODO need to add 'load model' in settings and queue loading based on that 343 String modelPath = updatedSettingsFields.getModelFilePath(); 344 String tokenizerPath = updatedSettingsFields.getTokenizerFilePath(); 345 double temperature = updatedSettingsFields.getTemperature(); 346 if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) { 347 if (updatedSettingsFields.getIsLoadModel() 348 || !modelPath.equals(mCurrentSettingsFields.getModelFilePath()) 349 || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath()) 350 || temperature != mCurrentSettingsFields.getTemperature()) { 351 loadLocalModelAndParameters( 352 updatedSettingsFields.getModelFilePath(), 353 updatedSettingsFields.getTokenizerFilePath(), 354 (float) updatedSettingsFields.getTemperature()); 355 updatedSettingsFields.saveLoadModelAction(false); 356 mDemoSharedPreferences.addSettings(updatedSettingsFields); 357 } 358 } else { 359 askUserToSelectModel(); 360 } 361 } 362 askUserToSelectModel()363 private void askUserToSelectModel() { 364 String askLoadModel = 365 "To get started, select your desired model and tokenizer " + "from the top right corner"; 366 Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0); 367 ETLogging.getInstance().log(askLoadModel); 368 runOnUiThread( 369 () -> { 370 mMessageAdapter.add(askLoadModelMessage); 371 mMessageAdapter.notifyDataSetChanged(); 372 }); 373 } 374 setupShowLogsButton()375 private void setupShowLogsButton() { 376 ImageButton showLogsButton = requireViewById(R.id.showLogsButton); 377 showLogsButton.setOnClickListener( 378 view -> { 379 Intent myIntent = new Intent(MainActivity.this, LogsActivity.class); 380 MainActivity.this.startActivity(myIntent); 381 }); 382 } 383 setupMediaButton()384 private void setupMediaButton() { 385 mAddMediaLayout = requireViewById(R.id.addMediaLayout); 386 mAddMediaLayout.setVisibility(View.GONE); // We hide this initially 387 388 ImageButton addMediaButton = requireViewById(R.id.addMediaButton); 389 addMediaButton.setOnClickListener( 390 view -> { 391 mAddMediaLayout.setVisibility(View.VISIBLE); 392 }); 393 394 mGalleryButton = requireViewById(R.id.galleryButton); 395 mGalleryButton.setOnClickListener( 396 view -> { 397 // Launch the photo picker and let the user choose only images. 398 mPickGallery.launch( 399 new PickVisualMediaRequest.Builder() 400 .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE) 401 .build()); 402 }); 403 mCameraButton = requireViewById(R.id.cameraButton); 404 mCameraButton.setOnClickListener( 405 view -> { 406 Log.d("CameraRoll", "Check permission"); 407 if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA) 408 != PackageManager.PERMISSION_GRANTED) { 409 ActivityCompat.requestPermissions( 410 MainActivity.this, 411 new String[] {Manifest.permission.CAMERA}, 412 REQUEST_IMAGE_CAPTURE); 413 } else { 414 launchCamera(); 415 } 416 }); 417 } 418 setupCameraRoll()419 private void setupCameraRoll() { 420 // Registers a camera roll activity launcher. 421 mCameraRoll = 422 registerForActivityResult( 423 new ActivityResultContracts.TakePicture(), 424 result -> { 425 if (result && cameraImageUri != null) { 426 Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri); 427 mAddMediaLayout.setVisibility(View.GONE); 428 List<Uri> uris = new ArrayList<>(); 429 uris.add(cameraImageUri); 430 showMediaPreview(uris); 431 } else { 432 // Delete the temp image file based on the url since the photo is not successfully 433 // taken 434 if (cameraImageUri != null) { 435 ContentResolver contentResolver = MainActivity.this.getContentResolver(); 436 contentResolver.delete(cameraImageUri, null, null); 437 Log.d("CameraRoll", "No photo taken. Delete temp uri"); 438 } 439 } 440 }); 441 mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); 442 ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); 443 mediaPreviewCloseButton.setOnClickListener( 444 view -> { 445 mMediaPreviewConstraintLayout.setVisibility(View.GONE); 446 mSelectedImageUri = null; 447 }); 448 449 ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); 450 addMoreImageButton.setOnClickListener( 451 view -> { 452 Log.d("addMore", "clicked"); 453 mMediaPreviewConstraintLayout.setVisibility(View.GONE); 454 // Direct user to select type of input 455 mCameraButton.callOnClick(); 456 }); 457 } 458 updateMemoryUsage()459 private String updateMemoryUsage() { 460 ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo(); 461 ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE); 462 if (activityManager == null) { 463 return "---"; 464 } 465 activityManager.getMemoryInfo(memoryInfo); 466 long totalMem = memoryInfo.totalMem / (1024 * 1024); 467 long availableMem = memoryInfo.availMem / (1024 * 1024); 468 long usedMem = totalMem - availableMem; 469 return usedMem + "MB"; 470 } 471 startMemoryUpdate()472 private void startMemoryUpdate() { 473 mMemoryView = requireViewById(R.id.ram_usage_live); 474 memoryUpdater = 475 new Runnable() { 476 @Override 477 public void run() { 478 mMemoryView.setText(updateMemoryUsage()); 479 mMemoryUpdateHandler.postDelayed(this, 1000); 480 } 481 }; 482 mMemoryUpdateHandler.post(memoryUpdater); 483 } 484 485 @Override onRequestPermissionsResult( int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults)486 public void onRequestPermissionsResult( 487 int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { 488 super.onRequestPermissionsResult(requestCode, permissions, grantResults); 489 if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) { 490 if (grantResults[0] == PackageManager.PERMISSION_GRANTED) { 491 launchCamera(); 492 } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) { 493 Log.d("CameraRoll", "Permission denied"); 494 } 495 } 496 } 497 launchCamera()498 private void launchCamera() { 499 ContentValues values = new ContentValues(); 500 values.put(MediaStore.Images.Media.TITLE, "New Picture"); 501 values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera"); 502 values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/"); 503 cameraImageUri = 504 MainActivity.this 505 .getContentResolver() 506 .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values); 507 mCameraRoll.launch(cameraImageUri); 508 } 509 setupGalleryPicker()510 private void setupGalleryPicker() { 511 // Registers a photo picker activity launcher in single-select mode. 512 mPickGallery = 513 registerForActivityResult( 514 new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES), 515 uris -> { 516 if (!uris.isEmpty()) { 517 Log.d("PhotoPicker", "Selected URIs: " + uris); 518 mAddMediaLayout.setVisibility(View.GONE); 519 for (Uri uri : uris) { 520 MainActivity.this 521 .getContentResolver() 522 .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION); 523 } 524 showMediaPreview(uris); 525 } else { 526 Log.d("PhotoPicker", "No media selected"); 527 } 528 }); 529 530 mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); 531 ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); 532 mediaPreviewCloseButton.setOnClickListener( 533 view -> { 534 mMediaPreviewConstraintLayout.setVisibility(View.GONE); 535 mSelectedImageUri = null; 536 }); 537 538 ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); 539 addMoreImageButton.setOnClickListener( 540 view -> { 541 Log.d("addMore", "clicked"); 542 mMediaPreviewConstraintLayout.setVisibility(View.GONE); 543 mGalleryButton.callOnClick(); 544 }); 545 } 546 getProcessedImagesForModel(List<Uri> uris)547 private List<ETImage> getProcessedImagesForModel(List<Uri> uris) { 548 List<ETImage> imageList = new ArrayList<>(); 549 if (uris != null) { 550 uris.forEach( 551 (uri) -> { 552 imageList.add(new ETImage(this.getContentResolver(), uri)); 553 }); 554 } 555 return imageList; 556 } 557 showMediaPreview(List<Uri> uris)558 private void showMediaPreview(List<Uri> uris) { 559 if (mSelectedImageUri == null) { 560 mSelectedImageUri = uris; 561 } else { 562 mSelectedImageUri.addAll(uris); 563 } 564 565 if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) { 566 mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES); 567 Toast.makeText( 568 this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT) 569 .show(); 570 } 571 Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri); 572 573 mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE); 574 575 List<ImageView> imageViews = new ArrayList<ImageView>(); 576 577 // Pre-populate all the image views that are available from the layout (currently max 5) 578 imageViews.add(requireViewById(R.id.mediaPreviewImageView1)); 579 imageViews.add(requireViewById(R.id.mediaPreviewImageView2)); 580 imageViews.add(requireViewById(R.id.mediaPreviewImageView3)); 581 imageViews.add(requireViewById(R.id.mediaPreviewImageView4)); 582 imageViews.add(requireViewById(R.id.mediaPreviewImageView5)); 583 584 // Hide all the image views (reset state) 585 for (int i = 0; i < imageViews.size(); i++) { 586 imageViews.get(i).setVisibility(View.GONE); 587 } 588 589 // Only show/render those that have proper Image URIs 590 for (int i = 0; i < mSelectedImageUri.size(); i++) { 591 imageViews.get(i).setVisibility(View.VISIBLE); 592 imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); 593 } 594 595 // For LLava, we want to call prefill_image as soon as an image is selected 596 // Llava only support 1 image for now 597 if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { 598 List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri); 599 if (!processedImageList.isEmpty()) { 600 mMessageAdapter.add( 601 new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)); 602 mMessageAdapter.notifyDataSetChanged(); 603 Runnable runnable = 604 () -> { 605 Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); 606 ETLogging.getInstance().log("Starting runnable prefill image"); 607 ETImage img = processedImageList.get(0); 608 ETLogging.getInstance().log("Llava start prefill image"); 609 startPos = 610 mModule.prefillImages( 611 img.getInts(), 612 img.getWidth(), 613 img.getHeight(), 614 ModelUtils.VISION_MODEL_IMAGE_CHANNELS, 615 startPos); 616 }; 617 executor.execute(runnable); 618 } 619 } 620 } 621 addSelectedImagesToChatThread(List<Uri> selectedImageUri)622 private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) { 623 if (selectedImageUri == null) { 624 return; 625 } 626 mMediaPreviewConstraintLayout.setVisibility(View.GONE); 627 for (int i = 0; i < selectedImageUri.size(); i++) { 628 Uri imageURI = selectedImageUri.get(i); 629 Log.d("image uri ", "test " + imageURI.getPath()); 630 mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0)); 631 } 632 mMessageAdapter.notifyDataSetChanged(); 633 } 634 getConversationHistory()635 private String getConversationHistory() { 636 String conversationHistory = ""; 637 638 ArrayList<Message> conversations = 639 mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK); 640 if (conversations.isEmpty()) { 641 return conversationHistory; 642 } 643 644 int prevPromptID = conversations.get(0).getPromptID(); 645 String conversationFormat = 646 PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType()); 647 String format = conversationFormat; 648 for (int i = 0; i < conversations.size(); i++) { 649 Message conversation = conversations.get(i); 650 int currentPromptID = conversation.getPromptID(); 651 if (currentPromptID != prevPromptID) { 652 conversationHistory = conversationHistory + format; 653 format = conversationFormat; 654 prevPromptID = currentPromptID; 655 } 656 if (conversation.getIsSent()) { 657 format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()); 658 } else { 659 format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); 660 } 661 } 662 conversationHistory = conversationHistory + format; 663 664 return conversationHistory; 665 } 666 getTotalFormattedPrompt(String conversationHistory, String rawPrompt)667 private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { 668 if (conversationHistory.isEmpty()) { 669 return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); 670 } 671 672 return mCurrentSettingsFields.getFormattedSystemPrompt() 673 + conversationHistory 674 + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt); 675 } 676 onModelRunStarted()677 private void onModelRunStarted() { 678 mSendButton.setClickable(false); 679 mSendButton.setImageResource(R.drawable.baseline_stop_24); 680 mSendButton.setOnClickListener( 681 view -> { 682 mModule.stop(); 683 }); 684 } 685 onModelRunStopped()686 private void onModelRunStopped() { 687 mSendButton.setClickable(true); 688 mSendButton.setImageResource(R.drawable.baseline_send_24); 689 mSendButton.setOnClickListener( 690 view -> { 691 try { 692 InputMethodManager imm = (InputMethodManager) getSystemService(INPUT_METHOD_SERVICE); 693 imm.hideSoftInputFromWindow(getCurrentFocus().getWindowToken(), 0); 694 } catch (Exception e) { 695 ETLogging.getInstance().log("Keyboard dismissal error: " + e.getMessage()); 696 } 697 addSelectedImagesToChatThread(mSelectedImageUri); 698 String finalPrompt; 699 String rawPrompt = mEditTextMessage.getText().toString(); 700 if (ModelUtils.getModelCategory( 701 mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) 702 == ModelUtils.VISION_MODEL) { 703 finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); 704 } else { 705 finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); 706 } 707 // We store raw prompt into message adapter, because we don't want to show the extra 708 // tokens from system prompt 709 mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID)); 710 mMessageAdapter.notifyDataSetChanged(); 711 mEditTextMessage.setText(""); 712 mResultMessage = new Message("", false, MessageType.TEXT, promptID); 713 mMessageAdapter.add(mResultMessage); 714 // Scroll to bottom of the list 715 mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); 716 // After images are added to prompt and chat thread, we clear the imageURI list 717 // Note: This has to be done after imageURIs are no longer needed by LlamaModule 718 mSelectedImageUri = null; 719 promptID++; 720 Runnable runnable = 721 new Runnable() { 722 @Override 723 public void run() { 724 Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); 725 ETLogging.getInstance().log("starting runnable generate()"); 726 runOnUiThread( 727 new Runnable() { 728 @Override 729 public void run() { 730 onModelRunStarted(); 731 } 732 }); 733 long generateStartTime = System.currentTimeMillis(); 734 if (ModelUtils.getModelCategory( 735 mCurrentSettingsFields.getModelType(), 736 mCurrentSettingsFields.getBackendType()) 737 == ModelUtils.VISION_MODEL) { 738 mModule.generateFromPos( 739 finalPrompt, 740 ModelUtils.VISION_MODEL_SEQ_LEN, 741 startPos, 742 MainActivity.this, 743 false); 744 } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { 745 String llamaGuardPromptForClassification = 746 PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); 747 ETLogging.getInstance() 748 .log("Running inference.. prompt=" + llamaGuardPromptForClassification); 749 mModule.generate( 750 llamaGuardPromptForClassification, 751 llamaGuardPromptForClassification.length() + 64, 752 MainActivity.this, 753 false); 754 } else { 755 ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); 756 mModule.generate( 757 finalPrompt, 758 (int) (finalPrompt.length() * 0.75) + 64, 759 MainActivity.this, 760 false); 761 } 762 763 long generateDuration = System.currentTimeMillis() - generateStartTime; 764 mResultMessage.setTotalGenerationTime(generateDuration); 765 runOnUiThread( 766 new Runnable() { 767 @Override 768 public void run() { 769 onModelRunStopped(); 770 } 771 }); 772 ETLogging.getInstance().log("Inference completed"); 773 } 774 }; 775 executor.execute(runnable); 776 }); 777 mMessageAdapter.notifyDataSetChanged(); 778 } 779 780 @Override run()781 public void run() { 782 runOnUiThread( 783 new Runnable() { 784 @Override 785 public void run() { 786 mMessageAdapter.notifyDataSetChanged(); 787 } 788 }); 789 } 790 791 @Override onBackPressed()792 public void onBackPressed() { 793 super.onBackPressed(); 794 if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) { 795 mAddMediaLayout.setVisibility(View.GONE); 796 } else { 797 // Default behavior of back button 798 finish(); 799 } 800 } 801 802 @Override onDestroy()803 protected void onDestroy() { 804 super.onDestroy(); 805 mMemoryUpdateHandler.removeCallbacks(memoryUpdater); 806 // This is to cover the case where the app is shutdown when user is on MainActivity but 807 // never clicked on the logsActivity 808 ETLogging.getInstance().saveLogs(); 809 } 810 } 811