• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package com.example.executorchllamademo;
10 
11 import android.Manifest;
12 import android.app.ActivityManager;
13 import android.app.AlertDialog;
14 import android.content.ContentResolver;
15 import android.content.ContentValues;
16 import android.content.Intent;
17 import android.content.pm.PackageManager;
18 import android.net.Uri;
19 import android.os.Build;
20 import android.os.Bundle;
21 import android.os.Handler;
22 import android.os.Looper;
23 import android.os.Process;
24 import android.provider.MediaStore;
25 import android.system.ErrnoException;
26 import android.system.Os;
27 import android.util.Log;
28 import android.view.View;
29 import android.view.inputmethod.InputMethodManager;
30 import android.widget.EditText;
31 import android.widget.ImageButton;
32 import android.widget.ImageView;
33 import android.widget.LinearLayout;
34 import android.widget.ListView;
35 import android.widget.TextView;
36 import android.widget.Toast;
37 import androidx.activity.result.ActivityResultLauncher;
38 import androidx.activity.result.PickVisualMediaRequest;
39 import androidx.activity.result.contract.ActivityResultContracts;
40 import androidx.annotation.NonNull;
41 import androidx.appcompat.app.AppCompatActivity;
42 import androidx.constraintlayout.widget.ConstraintLayout;
43 import androidx.core.app.ActivityCompat;
44 import androidx.core.content.ContextCompat;
45 import com.google.gson.Gson;
46 import com.google.gson.reflect.TypeToken;
47 import java.lang.reflect.Type;
48 import java.util.ArrayList;
49 import java.util.List;
50 import java.util.concurrent.Executor;
51 import java.util.concurrent.Executors;
52 import org.pytorch.executorch.LlamaCallback;
53 import org.pytorch.executorch.LlamaModule;
54 
55 public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback {
56   private EditText mEditTextMessage;
57   private ImageButton mSendButton;
58   private ImageButton mGalleryButton;
59   private ImageButton mCameraButton;
60   private ListView mMessagesView;
61   private MessageAdapter mMessageAdapter;
62   private LlamaModule mModule = null;
63   private Message mResultMessage = null;
64   private ImageButton mSettingsButton;
65   private TextView mMemoryView;
66   private ActivityResultLauncher<PickVisualMediaRequest> mPickGallery;
67   private ActivityResultLauncher<Uri> mCameraRoll;
68   private List<Uri> mSelectedImageUri;
69   private ConstraintLayout mMediaPreviewConstraintLayout;
70   private LinearLayout mAddMediaLayout;
71   private static final int MAX_NUM_OF_IMAGES = 5;
72   private static final int REQUEST_IMAGE_CAPTURE = 1;
73   private Uri cameraImageUri;
74   private DemoSharedPreferences mDemoSharedPreferences;
75   private SettingsFields mCurrentSettingsFields;
76   private Handler mMemoryUpdateHandler;
77   private Runnable memoryUpdater;
78   private int promptID = 0;
79   private long startPos = 0;
80   private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
81   private Executor executor;
82 
83   @Override
onResult(String result)84   public void onResult(String result) {
85     if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) {
86       return;
87     }
88     if (result.equals("\n\n") || result.equals("\n")) {
89       if (!mResultMessage.getText().isEmpty()) {
90         mResultMessage.appendText(result);
91         run();
92       }
93     } else {
94       mResultMessage.appendText(result);
95       run();
96     }
97   }
98 
99   @Override
onStats(float tps)100   public void onStats(float tps) {
101     runOnUiThread(
102         () -> {
103           if (mResultMessage != null) {
104             mResultMessage.setTokensPerSecond(tps);
105             mMessageAdapter.notifyDataSetChanged();
106           }
107         });
108   }
109 
setLocalModel(String modelPath, String tokenizerPath, float temperature)110   private void setLocalModel(String modelPath, String tokenizerPath, float temperature) {
111     Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0);
112     ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath);
113     runOnUiThread(
114         () -> {
115           mSendButton.setEnabled(false);
116           mMessageAdapter.add(modelLoadingMessage);
117           mMessageAdapter.notifyDataSetChanged();
118         });
119     if (mModule != null) {
120       ETLogging.getInstance().log("Start deallocating existing module instance");
121       mModule.resetNative();
122       mModule = null;
123       ETLogging.getInstance().log("Completed deallocating existing module instance");
124     }
125     long runStartTime = System.currentTimeMillis();
126     mModule =
127         new LlamaModule(
128             ModelUtils.getModelCategory(
129                 mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()),
130             modelPath,
131             tokenizerPath,
132             temperature);
133     int loadResult = mModule.load();
134     long loadDuration = System.currentTimeMillis() - runStartTime;
135     String modelLoadError = "";
136     String modelInfo = "";
137     if (loadResult != 0) {
138       // TODO: Map the error code to a reason to let the user know why model loading failed
139       modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n";
140       loadDuration = 0;
141       AlertDialog.Builder builder = new AlertDialog.Builder(this);
142       builder.setTitle("Load failed: " + loadResult);
143       runOnUiThread(
144           () -> {
145             AlertDialog alert = builder.create();
146             alert.show();
147           });
148     } else {
149       String[] segments = modelPath.split("/");
150       String pteName = segments[segments.length - 1];
151       segments = tokenizerPath.split("/");
152       String tokenizerName = segments[segments.length - 1];
153       modelInfo =
154           "Successfully loaded model. "
155               + pteName
156               + " and tokenizer "
157               + tokenizerName
158               + " in "
159               + (float) loadDuration / 1000
160               + " sec."
161               + " You can send text or image for inference";
162 
163       if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
164         ETLogging.getInstance().log("Llava start prefill prompt");
165         startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
166         ETLogging.getInstance().log("Llava completes prefill prompt");
167       }
168     }
169 
170     Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0);
171 
172     String modelLoggingInfo =
173         modelLoadError
174             + "Model path: "
175             + modelPath
176             + "\nTokenizer path: "
177             + tokenizerPath
178             + "\nBackend: "
179             + mCurrentSettingsFields.getBackendType().toString()
180             + "\nModelType: "
181             + ModelUtils.getModelCategory(
182                 mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
183             + "\nTemperature: "
184             + temperature
185             + "\nModel loaded time: "
186             + loadDuration
187             + " ms";
188     ETLogging.getInstance().log("Load complete. " + modelLoggingInfo);
189 
190     runOnUiThread(
191         () -> {
192           mSendButton.setEnabled(true);
193           mMessageAdapter.remove(modelLoadingMessage);
194           mMessageAdapter.add(modelLoadedMessage);
195           mMessageAdapter.notifyDataSetChanged();
196         });
197   }
198 
loadLocalModelAndParameters( String modelFilePath, String tokenizerFilePath, float temperature)199   private void loadLocalModelAndParameters(
200       String modelFilePath, String tokenizerFilePath, float temperature) {
201     Runnable runnable =
202         new Runnable() {
203           @Override
204           public void run() {
205             setLocalModel(modelFilePath, tokenizerFilePath, temperature);
206           }
207         };
208     new Thread(runnable).start();
209   }
210 
populateExistingMessages(String existingMsgJSON)211   private void populateExistingMessages(String existingMsgJSON) {
212     Gson gson = new Gson();
213     Type type = new TypeToken<ArrayList<Message>>() {}.getType();
214     ArrayList<Message> savedMessages = gson.fromJson(existingMsgJSON, type);
215     for (Message msg : savedMessages) {
216       mMessageAdapter.add(msg);
217     }
218     mMessageAdapter.notifyDataSetChanged();
219   }
220 
setPromptID()221   private int setPromptID() {
222 
223     return mMessageAdapter.getMaxPromptID() + 1;
224   }
225 
226   @Override
onCreate(Bundle savedInstanceState)227   protected void onCreate(Bundle savedInstanceState) {
228     super.onCreate(savedInstanceState);
229     setContentView(R.layout.activity_main);
230 
231     if (Build.VERSION.SDK_INT >= 21) {
232       getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar));
233       getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar));
234     }
235 
236     try {
237       Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
238       Os.setenv("LD_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
239     } catch (ErrnoException e) {
240       finish();
241     }
242 
243     mEditTextMessage = requireViewById(R.id.editTextMessage);
244     mSendButton = requireViewById(R.id.sendButton);
245     mSendButton.setEnabled(false);
246     mMessagesView = requireViewById(R.id.messages_view);
247     mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList<Message>());
248     mMessagesView.setAdapter(mMessageAdapter);
249     mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext());
250     String existingMsgJSON = mDemoSharedPreferences.getSavedMessages();
251     if (!existingMsgJSON.isEmpty()) {
252       populateExistingMessages(existingMsgJSON);
253       promptID = setPromptID();
254     }
255     mSettingsButton = requireViewById(R.id.settings);
256     mSettingsButton.setOnClickListener(
257         view -> {
258           Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class);
259           MainActivity.this.startActivity(myIntent);
260         });
261 
262     mCurrentSettingsFields = new SettingsFields();
263     mMemoryUpdateHandler = new Handler(Looper.getMainLooper());
264     onModelRunStopped();
265     setupMediaButton();
266     setupGalleryPicker();
267     setupCameraRoll();
268     startMemoryUpdate();
269     setupShowLogsButton();
270     executor = Executors.newSingleThreadExecutor();
271   }
272 
273   @Override
onPause()274   protected void onPause() {
275     super.onPause();
276     mDemoSharedPreferences.addMessages(mMessageAdapter);
277   }
278 
279   @Override
onResume()280   protected void onResume() {
281     super.onResume();
282     // Check for if settings parameters have changed
283     Gson gson = new Gson();
284     String settingsFieldsJSON = mDemoSharedPreferences.getSettings();
285     if (!settingsFieldsJSON.isEmpty()) {
286       SettingsFields updatedSettingsFields =
287           gson.fromJson(settingsFieldsJSON, SettingsFields.class);
288       if (updatedSettingsFields == null) {
289         // Added this check, because gson.fromJson can return null
290         askUserToSelectModel();
291         return;
292       }
293       boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields);
294       boolean isLoadModel = updatedSettingsFields.getIsLoadModel();
295       setBackendMode(updatedSettingsFields.getBackendType());
296       if (isUpdated) {
297         if (isLoadModel) {
298           // If users change the model file, but not pressing loadModelButton, we won't load the new
299           // model
300           checkForUpdateAndReloadModel(updatedSettingsFields);
301         } else {
302           askUserToSelectModel();
303         }
304 
305         checkForClearChatHistory(updatedSettingsFields);
306         // Update current to point to the latest
307         mCurrentSettingsFields = new SettingsFields(updatedSettingsFields);
308       }
309     } else {
310       askUserToSelectModel();
311     }
312   }
313 
setBackendMode(BackendType backendType)314   private void setBackendMode(BackendType backendType) {
315     if (backendType.equals(BackendType.XNNPACK) || backendType.equals(BackendType.QUALCOMM)) {
316       setXNNPACKMode();
317     } else if (backendType.equals(BackendType.MEDIATEK)) {
318       setMediaTekMode();
319     }
320   }
321 
setXNNPACKMode()322   private void setXNNPACKMode() {
323     requireViewById(R.id.addMediaButton).setVisibility(View.VISIBLE);
324   }
325 
setMediaTekMode()326   private void setMediaTekMode() {
327     requireViewById(R.id.addMediaButton).setVisibility(View.GONE);
328   }
329 
checkForClearChatHistory(SettingsFields updatedSettingsFields)330   private void checkForClearChatHistory(SettingsFields updatedSettingsFields) {
331     if (updatedSettingsFields.getIsClearChatHistory()) {
332       mMessageAdapter.clear();
333       mMessageAdapter.notifyDataSetChanged();
334       mDemoSharedPreferences.removeExistingMessages();
335       // changing to false since chat history has been cleared.
336       updatedSettingsFields.saveIsClearChatHistory(false);
337       mDemoSharedPreferences.addSettings(updatedSettingsFields);
338     }
339   }
340 
checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields)341   private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) {
342     // TODO need to add 'load model' in settings and queue loading based on that
343     String modelPath = updatedSettingsFields.getModelFilePath();
344     String tokenizerPath = updatedSettingsFields.getTokenizerFilePath();
345     double temperature = updatedSettingsFields.getTemperature();
346     if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) {
347       if (updatedSettingsFields.getIsLoadModel()
348           || !modelPath.equals(mCurrentSettingsFields.getModelFilePath())
349           || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath())
350           || temperature != mCurrentSettingsFields.getTemperature()) {
351         loadLocalModelAndParameters(
352             updatedSettingsFields.getModelFilePath(),
353             updatedSettingsFields.getTokenizerFilePath(),
354             (float) updatedSettingsFields.getTemperature());
355         updatedSettingsFields.saveLoadModelAction(false);
356         mDemoSharedPreferences.addSettings(updatedSettingsFields);
357       }
358     } else {
359       askUserToSelectModel();
360     }
361   }
362 
askUserToSelectModel()363   private void askUserToSelectModel() {
364     String askLoadModel =
365         "To get started, select your desired model and tokenizer " + "from the top right corner";
366     Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0);
367     ETLogging.getInstance().log(askLoadModel);
368     runOnUiThread(
369         () -> {
370           mMessageAdapter.add(askLoadModelMessage);
371           mMessageAdapter.notifyDataSetChanged();
372         });
373   }
374 
setupShowLogsButton()375   private void setupShowLogsButton() {
376     ImageButton showLogsButton = requireViewById(R.id.showLogsButton);
377     showLogsButton.setOnClickListener(
378         view -> {
379           Intent myIntent = new Intent(MainActivity.this, LogsActivity.class);
380           MainActivity.this.startActivity(myIntent);
381         });
382   }
383 
setupMediaButton()384   private void setupMediaButton() {
385     mAddMediaLayout = requireViewById(R.id.addMediaLayout);
386     mAddMediaLayout.setVisibility(View.GONE); // We hide this initially
387 
388     ImageButton addMediaButton = requireViewById(R.id.addMediaButton);
389     addMediaButton.setOnClickListener(
390         view -> {
391           mAddMediaLayout.setVisibility(View.VISIBLE);
392         });
393 
394     mGalleryButton = requireViewById(R.id.galleryButton);
395     mGalleryButton.setOnClickListener(
396         view -> {
397           // Launch the photo picker and let the user choose only images.
398           mPickGallery.launch(
399               new PickVisualMediaRequest.Builder()
400                   .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE)
401                   .build());
402         });
403     mCameraButton = requireViewById(R.id.cameraButton);
404     mCameraButton.setOnClickListener(
405         view -> {
406           Log.d("CameraRoll", "Check permission");
407           if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA)
408               != PackageManager.PERMISSION_GRANTED) {
409             ActivityCompat.requestPermissions(
410                 MainActivity.this,
411                 new String[] {Manifest.permission.CAMERA},
412                 REQUEST_IMAGE_CAPTURE);
413           } else {
414             launchCamera();
415           }
416         });
417   }
418 
setupCameraRoll()419   private void setupCameraRoll() {
420     // Registers a camera roll activity launcher.
421     mCameraRoll =
422         registerForActivityResult(
423             new ActivityResultContracts.TakePicture(),
424             result -> {
425               if (result && cameraImageUri != null) {
426                 Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri);
427                 mAddMediaLayout.setVisibility(View.GONE);
428                 List<Uri> uris = new ArrayList<>();
429                 uris.add(cameraImageUri);
430                 showMediaPreview(uris);
431               } else {
432                 // Delete the temp image file based on the url since the photo is not successfully
433                 // taken
434                 if (cameraImageUri != null) {
435                   ContentResolver contentResolver = MainActivity.this.getContentResolver();
436                   contentResolver.delete(cameraImageUri, null, null);
437                   Log.d("CameraRoll", "No photo taken. Delete temp uri");
438                 }
439               }
440             });
441     mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout);
442     ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton);
443     mediaPreviewCloseButton.setOnClickListener(
444         view -> {
445           mMediaPreviewConstraintLayout.setVisibility(View.GONE);
446           mSelectedImageUri = null;
447         });
448 
449     ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton);
450     addMoreImageButton.setOnClickListener(
451         view -> {
452           Log.d("addMore", "clicked");
453           mMediaPreviewConstraintLayout.setVisibility(View.GONE);
454           // Direct user to select type of input
455           mCameraButton.callOnClick();
456         });
457   }
458 
updateMemoryUsage()459   private String updateMemoryUsage() {
460     ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo();
461     ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE);
462     if (activityManager == null) {
463       return "---";
464     }
465     activityManager.getMemoryInfo(memoryInfo);
466     long totalMem = memoryInfo.totalMem / (1024 * 1024);
467     long availableMem = memoryInfo.availMem / (1024 * 1024);
468     long usedMem = totalMem - availableMem;
469     return usedMem + "MB";
470   }
471 
startMemoryUpdate()472   private void startMemoryUpdate() {
473     mMemoryView = requireViewById(R.id.ram_usage_live);
474     memoryUpdater =
475         new Runnable() {
476           @Override
477           public void run() {
478             mMemoryView.setText(updateMemoryUsage());
479             mMemoryUpdateHandler.postDelayed(this, 1000);
480           }
481         };
482     mMemoryUpdateHandler.post(memoryUpdater);
483   }
484 
485   @Override
onRequestPermissionsResult( int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults)486   public void onRequestPermissionsResult(
487       int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
488     super.onRequestPermissionsResult(requestCode, permissions, grantResults);
489     if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) {
490       if (grantResults[0] == PackageManager.PERMISSION_GRANTED) {
491         launchCamera();
492       } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
493         Log.d("CameraRoll", "Permission denied");
494       }
495     }
496   }
497 
launchCamera()498   private void launchCamera() {
499     ContentValues values = new ContentValues();
500     values.put(MediaStore.Images.Media.TITLE, "New Picture");
501     values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera");
502     values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/");
503     cameraImageUri =
504         MainActivity.this
505             .getContentResolver()
506             .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values);
507     mCameraRoll.launch(cameraImageUri);
508   }
509 
setupGalleryPicker()510   private void setupGalleryPicker() {
511     // Registers a photo picker activity launcher in single-select mode.
512     mPickGallery =
513         registerForActivityResult(
514             new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES),
515             uris -> {
516               if (!uris.isEmpty()) {
517                 Log.d("PhotoPicker", "Selected URIs: " + uris);
518                 mAddMediaLayout.setVisibility(View.GONE);
519                 for (Uri uri : uris) {
520                   MainActivity.this
521                       .getContentResolver()
522                       .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION);
523                 }
524                 showMediaPreview(uris);
525               } else {
526                 Log.d("PhotoPicker", "No media selected");
527               }
528             });
529 
530     mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout);
531     ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton);
532     mediaPreviewCloseButton.setOnClickListener(
533         view -> {
534           mMediaPreviewConstraintLayout.setVisibility(View.GONE);
535           mSelectedImageUri = null;
536         });
537 
538     ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton);
539     addMoreImageButton.setOnClickListener(
540         view -> {
541           Log.d("addMore", "clicked");
542           mMediaPreviewConstraintLayout.setVisibility(View.GONE);
543           mGalleryButton.callOnClick();
544         });
545   }
546 
getProcessedImagesForModel(List<Uri> uris)547   private List<ETImage> getProcessedImagesForModel(List<Uri> uris) {
548     List<ETImage> imageList = new ArrayList<>();
549     if (uris != null) {
550       uris.forEach(
551           (uri) -> {
552             imageList.add(new ETImage(this.getContentResolver(), uri));
553           });
554     }
555     return imageList;
556   }
557 
showMediaPreview(List<Uri> uris)558   private void showMediaPreview(List<Uri> uris) {
559     if (mSelectedImageUri == null) {
560       mSelectedImageUri = uris;
561     } else {
562       mSelectedImageUri.addAll(uris);
563     }
564 
565     if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) {
566       mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES);
567       Toast.makeText(
568               this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT)
569           .show();
570     }
571     Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri);
572 
573     mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE);
574 
575     List<ImageView> imageViews = new ArrayList<ImageView>();
576 
577     // Pre-populate all the image views that are available from the layout (currently max 5)
578     imageViews.add(requireViewById(R.id.mediaPreviewImageView1));
579     imageViews.add(requireViewById(R.id.mediaPreviewImageView2));
580     imageViews.add(requireViewById(R.id.mediaPreviewImageView3));
581     imageViews.add(requireViewById(R.id.mediaPreviewImageView4));
582     imageViews.add(requireViewById(R.id.mediaPreviewImageView5));
583 
584     // Hide all the image views (reset state)
585     for (int i = 0; i < imageViews.size(); i++) {
586       imageViews.get(i).setVisibility(View.GONE);
587     }
588 
589     // Only show/render those that have proper Image URIs
590     for (int i = 0; i < mSelectedImageUri.size(); i++) {
591       imageViews.get(i).setVisibility(View.VISIBLE);
592       imageViews.get(i).setImageURI(mSelectedImageUri.get(i));
593     }
594 
595     // For LLava, we want to call prefill_image as soon as an image is selected
596     // Llava only support 1 image for now
597     if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
598       List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
599       if (!processedImageList.isEmpty()) {
600         mMessageAdapter.add(
601             new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0));
602         mMessageAdapter.notifyDataSetChanged();
603         Runnable runnable =
604             () -> {
605               Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
606               ETLogging.getInstance().log("Starting runnable prefill image");
607               ETImage img = processedImageList.get(0);
608               ETLogging.getInstance().log("Llava start prefill image");
609               startPos =
610                   mModule.prefillImages(
611                       img.getInts(),
612                       img.getWidth(),
613                       img.getHeight(),
614                       ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
615                       startPos);
616             };
617         executor.execute(runnable);
618       }
619     }
620   }
621 
addSelectedImagesToChatThread(List<Uri> selectedImageUri)622   private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
623     if (selectedImageUri == null) {
624       return;
625     }
626     mMediaPreviewConstraintLayout.setVisibility(View.GONE);
627     for (int i = 0; i < selectedImageUri.size(); i++) {
628       Uri imageURI = selectedImageUri.get(i);
629       Log.d("image uri ", "test " + imageURI.getPath());
630       mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0));
631     }
632     mMessageAdapter.notifyDataSetChanged();
633   }
634 
getConversationHistory()635   private String getConversationHistory() {
636     String conversationHistory = "";
637 
638     ArrayList<Message> conversations =
639         mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
640     if (conversations.isEmpty()) {
641       return conversationHistory;
642     }
643 
644     int prevPromptID = conversations.get(0).getPromptID();
645     String conversationFormat =
646         PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
647     String format = conversationFormat;
648     for (int i = 0; i < conversations.size(); i++) {
649       Message conversation = conversations.get(i);
650       int currentPromptID = conversation.getPromptID();
651       if (currentPromptID != prevPromptID) {
652         conversationHistory = conversationHistory + format;
653         format = conversationFormat;
654         prevPromptID = currentPromptID;
655       }
656       if (conversation.getIsSent()) {
657         format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText());
658       } else {
659         format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
660       }
661     }
662     conversationHistory = conversationHistory + format;
663 
664     return conversationHistory;
665   }
666 
getTotalFormattedPrompt(String conversationHistory, String rawPrompt)667   private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
668     if (conversationHistory.isEmpty()) {
669       return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
670     }
671 
672     return mCurrentSettingsFields.getFormattedSystemPrompt()
673         + conversationHistory
674         + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt);
675   }
676 
onModelRunStarted()677   private void onModelRunStarted() {
678     mSendButton.setClickable(false);
679     mSendButton.setImageResource(R.drawable.baseline_stop_24);
680     mSendButton.setOnClickListener(
681         view -> {
682           mModule.stop();
683         });
684   }
685 
onModelRunStopped()686   private void onModelRunStopped() {
687     mSendButton.setClickable(true);
688     mSendButton.setImageResource(R.drawable.baseline_send_24);
689     mSendButton.setOnClickListener(
690         view -> {
691           try {
692             InputMethodManager imm = (InputMethodManager) getSystemService(INPUT_METHOD_SERVICE);
693             imm.hideSoftInputFromWindow(getCurrentFocus().getWindowToken(), 0);
694           } catch (Exception e) {
695             ETLogging.getInstance().log("Keyboard dismissal error: " + e.getMessage());
696           }
697           addSelectedImagesToChatThread(mSelectedImageUri);
698           String finalPrompt;
699           String rawPrompt = mEditTextMessage.getText().toString();
700           if (ModelUtils.getModelCategory(
701                   mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
702               == ModelUtils.VISION_MODEL) {
703             finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
704           } else {
705             finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
706           }
707           // We store raw prompt into message adapter, because we don't want to show the extra
708           // tokens from system prompt
709           mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
710           mMessageAdapter.notifyDataSetChanged();
711           mEditTextMessage.setText("");
712           mResultMessage = new Message("", false, MessageType.TEXT, promptID);
713           mMessageAdapter.add(mResultMessage);
714           // Scroll to bottom of the list
715           mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1);
716           // After images are added to prompt and chat thread, we clear the imageURI list
717           // Note: This has to be done after imageURIs are no longer needed by LlamaModule
718           mSelectedImageUri = null;
719           promptID++;
720           Runnable runnable =
721               new Runnable() {
722                 @Override
723                 public void run() {
724                   Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
725                   ETLogging.getInstance().log("starting runnable generate()");
726                   runOnUiThread(
727                       new Runnable() {
728                         @Override
729                         public void run() {
730                           onModelRunStarted();
731                         }
732                       });
733                   long generateStartTime = System.currentTimeMillis();
734                   if (ModelUtils.getModelCategory(
735                           mCurrentSettingsFields.getModelType(),
736                           mCurrentSettingsFields.getBackendType())
737                       == ModelUtils.VISION_MODEL) {
738                     mModule.generateFromPos(
739                         finalPrompt,
740                         ModelUtils.VISION_MODEL_SEQ_LEN,
741                         startPos,
742                         MainActivity.this,
743                         false);
744                   } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
745                     String llamaGuardPromptForClassification =
746                         PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt);
747                     ETLogging.getInstance()
748                         .log("Running inference.. prompt=" + llamaGuardPromptForClassification);
749                     mModule.generate(
750                         llamaGuardPromptForClassification,
751                         llamaGuardPromptForClassification.length() + 64,
752                         MainActivity.this,
753                         false);
754                   } else {
755                     ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
756                     mModule.generate(
757                         finalPrompt,
758                         (int) (finalPrompt.length() * 0.75) + 64,
759                         MainActivity.this,
760                         false);
761                   }
762 
763                   long generateDuration = System.currentTimeMillis() - generateStartTime;
764                   mResultMessage.setTotalGenerationTime(generateDuration);
765                   runOnUiThread(
766                       new Runnable() {
767                         @Override
768                         public void run() {
769                           onModelRunStopped();
770                         }
771                       });
772                   ETLogging.getInstance().log("Inference completed");
773                 }
774               };
775           executor.execute(runnable);
776         });
777     mMessageAdapter.notifyDataSetChanged();
778   }
779 
780   @Override
run()781   public void run() {
782     runOnUiThread(
783         new Runnable() {
784           @Override
785           public void run() {
786             mMessageAdapter.notifyDataSetChanged();
787           }
788         });
789   }
790 
791   @Override
onBackPressed()792   public void onBackPressed() {
793     super.onBackPressed();
794     if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) {
795       mAddMediaLayout.setVisibility(View.GONE);
796     } else {
797       // Default behavior of back button
798       finish();
799     }
800   }
801 
802   @Override
onDestroy()803   protected void onDestroy() {
804     super.onDestroy();
805     mMemoryUpdateHandler.removeCallbacks(memoryUpdater);
806     // This is to cover the case where the app is shutdown when user is on MainActivity but
807     // never clicked on the logsActivity
808     ETLogging.getInstance().saveLogs();
809   }
810 }
811