• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package com.example.executorchllamademo;
10 
11 import android.net.Uri;
12 import android.view.LayoutInflater;
13 import android.view.View;
14 import android.view.ViewGroup;
15 import android.widget.ArrayAdapter;
16 import android.widget.ImageView;
17 import android.widget.TextView;
18 import java.util.ArrayList;
19 import java.util.Collections;
20 
21 public class MessageAdapter extends ArrayAdapter<Message> {
22 
23   private final ArrayList<Message> savedMessages;
24 
MessageAdapter( android.content.Context context, int resource, ArrayList<Message> savedMessages)25   public MessageAdapter(
26       android.content.Context context, int resource, ArrayList<Message> savedMessages) {
27     super(context, resource);
28     this.savedMessages = savedMessages;
29   }
30 
31   @Override
getView(int position, View convertView, ViewGroup parent)32   public View getView(int position, View convertView, ViewGroup parent) {
33     Message currentMessage = getItem(position);
34     int layoutIdForListItem;
35 
36     if (currentMessage.getMessageType() == MessageType.SYSTEM) {
37       layoutIdForListItem = R.layout.system_message;
38     } else {
39       layoutIdForListItem =
40           currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message;
41     }
42     View listItemView =
43         LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false);
44     if (currentMessage.getMessageType() == MessageType.IMAGE) {
45       ImageView messageImageView = listItemView.requireViewById(R.id.message_image);
46       messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath()));
47       TextView messageTextView = listItemView.requireViewById(R.id.message_text);
48       messageTextView.setVisibility(View.GONE);
49     } else {
50       TextView messageTextView = listItemView.requireViewById(R.id.message_text);
51       messageTextView.setText(currentMessage.getText());
52     }
53 
54     String metrics = "";
55     TextView tokensView;
56     if (currentMessage.getTokensPerSecond() > 0) {
57       metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s  ";
58     }
59 
60     if (currentMessage.getTotalGenerationTime() > 0) {
61       metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s  ";
62     }
63 
64     if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) {
65       tokensView = listItemView.requireViewById(R.id.generation_metrics);
66       tokensView.setText(metrics);
67       TextView separatorView = listItemView.requireViewById(R.id.bar);
68       separatorView.setVisibility(View.VISIBLE);
69     }
70 
71     if (currentMessage.getTimestamp() > 0) {
72       TextView timestampView = listItemView.requireViewById(R.id.timestamp);
73       timestampView.setText(currentMessage.getFormattedTimestamp());
74     }
75 
76     return listItemView;
77   }
78 
79   @Override
add(Message msg)80   public void add(Message msg) {
81     super.add(msg);
82     savedMessages.add(msg);
83   }
84 
85   @Override
clear()86   public void clear() {
87     super.clear();
88     savedMessages.clear();
89   }
90 
getSavedMessages()91   public ArrayList<Message> getSavedMessages() {
92     return savedMessages;
93   }
94 
getRecentSavedTextMessages(int numOfLatestPromptMessages)95   public ArrayList<Message> getRecentSavedTextMessages(int numOfLatestPromptMessages) {
96     ArrayList<Message> recentMessages = new ArrayList<Message>();
97     int lastIndex = savedMessages.size() - 1;
98     // In most cases lastIndex >=0 .
99     // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 .
100     if (lastIndex >= 0) {
101       Message messageToAdd = savedMessages.get(lastIndex);
102       int oldPromptID = messageToAdd.getPromptID();
103 
104       for (int i = 0; i < savedMessages.size(); i++) {
105         messageToAdd = savedMessages.get(lastIndex - i);
106         if (messageToAdd.getMessageType() != MessageType.SYSTEM) {
107           if (messageToAdd.getPromptID() != oldPromptID) {
108             numOfLatestPromptMessages--;
109             oldPromptID = messageToAdd.getPromptID();
110           }
111           if (numOfLatestPromptMessages > 0) {
112             if (messageToAdd.getMessageType() == MessageType.TEXT) {
113               recentMessages.add(messageToAdd);
114             }
115           } else {
116             break;
117           }
118         }
119       }
120       // To place the order in [input1, output1, input2, output2...]
121       Collections.reverse(recentMessages);
122     }
123 
124     return recentMessages;
125   }
126 
getMaxPromptID()127   public int getMaxPromptID() {
128     int maxPromptID = -1;
129     for (Message msg : savedMessages) {
130 
131       maxPromptID = Math.max(msg.getPromptID(), maxPromptID);
132     }
133     return maxPromptID;
134   }
135 }
136