1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 package com.example.executorchllamademo; 10 11 import android.net.Uri; 12 import android.view.LayoutInflater; 13 import android.view.View; 14 import android.view.ViewGroup; 15 import android.widget.ArrayAdapter; 16 import android.widget.ImageView; 17 import android.widget.TextView; 18 import java.util.ArrayList; 19 import java.util.Collections; 20 21 public class MessageAdapter extends ArrayAdapter<Message> { 22 23 private final ArrayList<Message> savedMessages; 24 MessageAdapter( android.content.Context context, int resource, ArrayList<Message> savedMessages)25 public MessageAdapter( 26 android.content.Context context, int resource, ArrayList<Message> savedMessages) { 27 super(context, resource); 28 this.savedMessages = savedMessages; 29 } 30 31 @Override getView(int position, View convertView, ViewGroup parent)32 public View getView(int position, View convertView, ViewGroup parent) { 33 Message currentMessage = getItem(position); 34 int layoutIdForListItem; 35 36 if (currentMessage.getMessageType() == MessageType.SYSTEM) { 37 layoutIdForListItem = R.layout.system_message; 38 } else { 39 layoutIdForListItem = 40 currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; 41 } 42 View listItemView = 43 LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); 44 if (currentMessage.getMessageType() == MessageType.IMAGE) { 45 ImageView messageImageView = listItemView.requireViewById(R.id.message_image); 46 messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath())); 47 TextView messageTextView = listItemView.requireViewById(R.id.message_text); 48 messageTextView.setVisibility(View.GONE); 49 } else { 50 TextView messageTextView = listItemView.requireViewById(R.id.message_text); 51 messageTextView.setText(currentMessage.getText()); 52 } 53 54 String metrics = ""; 55 TextView tokensView; 56 if (currentMessage.getTokensPerSecond() > 0) { 57 metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s "; 58 } 59 60 if (currentMessage.getTotalGenerationTime() > 0) { 61 metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s "; 62 } 63 64 if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) { 65 tokensView = listItemView.requireViewById(R.id.generation_metrics); 66 tokensView.setText(metrics); 67 TextView separatorView = listItemView.requireViewById(R.id.bar); 68 separatorView.setVisibility(View.VISIBLE); 69 } 70 71 if (currentMessage.getTimestamp() > 0) { 72 TextView timestampView = listItemView.requireViewById(R.id.timestamp); 73 timestampView.setText(currentMessage.getFormattedTimestamp()); 74 } 75 76 return listItemView; 77 } 78 79 @Override add(Message msg)80 public void add(Message msg) { 81 super.add(msg); 82 savedMessages.add(msg); 83 } 84 85 @Override clear()86 public void clear() { 87 super.clear(); 88 savedMessages.clear(); 89 } 90 getSavedMessages()91 public ArrayList<Message> getSavedMessages() { 92 return savedMessages; 93 } 94 getRecentSavedTextMessages(int numOfLatestPromptMessages)95 public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) { 96 ArrayList<Message> recentMessages = new ArrayList<Message>(); 97 int lastIndex = savedMessages.size() - 1; 98 // In most cases lastIndex >=0 . 99 // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 . 100 if (lastIndex >= 0) { 101 Message messageToAdd = savedMessages.get(lastIndex); 102 int oldPromptID = messageToAdd.getPromptID(); 103 104 for (int i = 0; i < savedMessages.size(); i++) { 105 messageToAdd = savedMessages.get(lastIndex - i); 106 if (messageToAdd.getMessageType() != MessageType.SYSTEM) { 107 if (messageToAdd.getPromptID() != oldPromptID) { 108 numOfLatestPromptMessages--; 109 oldPromptID = messageToAdd.getPromptID(); 110 } 111 if (numOfLatestPromptMessages > 0) { 112 if (messageToAdd.getMessageType() == MessageType.TEXT) { 113 recentMessages.add(messageToAdd); 114 } 115 } else { 116 break; 117 } 118 } 119 } 120 // To place the order in [input1, output1, input2, output2...] 121 Collections.reverse(recentMessages); 122 } 123 124 return recentMessages; 125 } 126 getMaxPromptID()127 public int getMaxPromptID() { 128 int maxPromptID = -1; 129 for (Message msg : savedMessages) { 130 131 maxPromptID = Math.max(msg.getPromptID(), maxPromptID); 132 } 133 return maxPromptID; 134 } 135 } 136