• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 package com.example.executorchllamademo;
10 
import android.app.Activity;
import android.app.ActivityManager;
import android.content.Intent;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.NonNull;
import com.google.gson.Gson;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
28 
29 public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
30   ModelRunner mModelRunner;
31 
32   String mPrompt;
33   TextView mTextView;
34   StatsDump mStatsDump;
35 
36   @Override
onCreate(Bundle savedInstanceState)37   protected void onCreate(Bundle savedInstanceState) {
38     super.onCreate(savedInstanceState);
39     setContentView(R.layout.activity_benchmarking);
40     mTextView = findViewById(R.id.log_view);
41 
42     Intent intent = getIntent();
43 
44     File modelDir = new File(intent.getStringExtra("model_dir"));
45     File model =
46         Arrays.stream(modelDir.listFiles())
47             .filter(file -> file.getName().endsWith(".pte"))
48             .findFirst()
49             .get();
50     String tokenizerPath = intent.getStringExtra("tokenizer_path");
51 
52     float temperature = intent.getFloatExtra("temperature", 0.8f);
53     mPrompt = intent.getStringExtra("prompt");
54     if (mPrompt == null) {
55       mPrompt = "The ultimate answer";
56     }
57 
58     mStatsDump = new StatsDump();
59     mStatsDump.modelName = model.getName().replace(".pte", "");
60     mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
61     mStatsDump.loadStart = System.nanoTime();
62   }
63 
64   @Override
onModelLoaded(int status)65   public void onModelLoaded(int status) {
66     mStatsDump.loadEnd = System.nanoTime();
67     mStatsDump.loadStatus = status;
68     if (status != 0) {
69       Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
70       onGenerationStopped();
71       return;
72     }
73     mStatsDump.generateStart = System.nanoTime();
74     mModelRunner.generate(mPrompt);
75   }
76 
77   @Override
onTokenGenerated(String token)78   public void onTokenGenerated(String token) {
79     runOnUiThread(
80         () -> {
81           mTextView.append(token);
82         });
83   }
84 
85   @Override
onStats(String stats)86   public void onStats(String stats) {
87     mStatsDump.tokens = stats;
88   }
89 
90   @Override
onGenerationStopped()91   public void onGenerationStopped() {
92     mStatsDump.generateEnd = System.nanoTime();
93     runOnUiThread(
94         () -> {
95           mTextView.append(mStatsDump.toString());
96         });
97 
98     final BenchmarkMetric.BenchmarkModel benchmarkModel =
99         BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName);
100     final List<BenchmarkMetric> results = new ArrayList<>();
101     // The list of metrics we have atm includes:
102     // Load status
103     results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
104     // Model load time
105     results.add(
106         new BenchmarkMetric(
107             benchmarkModel,
108             "model_load_time(ms)",
109             (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6,
110             0.0f));
111     // LLM generate time
112     results.add(
113         new BenchmarkMetric(
114             benchmarkModel,
115             "generate_time(ms)",
116             (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6,
117             0.0f));
118     // Token per second
119     results.add(
120         new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));
121 
122     try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
123       Gson gson = new Gson();
124       writer.write(gson.toJson(results));
125     } catch (IOException e) {
126       e.printStackTrace();
127     }
128   }
129 
extractTPS(final String tokens)130   private double extractTPS(final String tokens) {
131     final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
132     if (m.find()) {
133       return Double.parseDouble(m.group());
134     } else {
135       return 0.0f;
136     }
137   }
138 }
139 
140 class BenchmarkMetric {
141   public static class BenchmarkModel {
142     // The model name, i.e. stories110M
143     String name;
144     String backend;
145     String quantization;
146 
BenchmarkModel(final String name, final String backend, final String quantization)147     public BenchmarkModel(final String name, final String backend, final String quantization) {
148       this.name = name;
149       this.backend = backend;
150       this.quantization = quantization;
151     }
152   }
153 
154   BenchmarkModel benchmarkModel;
155 
156   // The metric name, i.e. TPS
157   String metric;
158 
159   // The actual value and the option target value
160   double actualValue;
161   double targetValue;
162 
163   public static class DeviceInfo {
164     // Let's see which information we want to include here
165     final String device = Build.BRAND;
166     // The phone model and Android release version
167     final String arch = Build.MODEL;
168     final String os = "Android " + Build.VERSION.RELEASE;
169     final long totalMem = new ActivityManager.MemoryInfo().totalMem;
170     final long availMem = new ActivityManager.MemoryInfo().availMem;
171   }
172 
173   DeviceInfo deviceInfo = new DeviceInfo();
174 
BenchmarkMetric( final BenchmarkModel benchmarkModel, final String metric, final double actualValue, final double targetValue)175   public BenchmarkMetric(
176       final BenchmarkModel benchmarkModel,
177       final String metric,
178       final double actualValue,
179       final double targetValue) {
180     this.benchmarkModel = benchmarkModel;
181     this.metric = metric;
182     this.actualValue = actualValue;
183     this.targetValue = targetValue;
184   }
185 
186   // TODO (huydhn): Figure out a way to extract the backend and quantization information from
187   // the .pte model itself instead of parsing its name
extractBackendAndQuantization(final String model)188   public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
189     final Matcher m =
190         Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
191     if (m.matches()) {
192       return new BenchmarkMetric.BenchmarkModel(
193           m.group("name"), m.group("backend"), m.group("quantization"));
194     } else {
195       return new BenchmarkMetric.BenchmarkModel(model, "", "");
196     }
197   }
198 }
199 
200 class StatsDump {
201   int loadStatus;
202   long loadStart;
203   long loadEnd;
204   long generateStart;
205   long generateEnd;
206   String tokens;
207   String modelName;
208 
209   @NonNull
210   @Override
toString()211   public String toString() {
212     return "loadStart: "
213         + loadStart
214         + "\nloadEnd: "
215         + loadEnd
216         + "\ngenerateStart: "
217         + generateStart
218         + "\ngenerateEnd: "
219         + generateEnd
220         + "\n"
221         + tokens;
222   }
223 }
224