/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package com.example.executorchllamademo;

import android.app.Activity;
import android.app.ActivityManager;
import android.content.Intent;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.NonNull;
import com.google.gson.Gson;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
  ModelRunner mModelRunner;

  String mPrompt;
  TextView mTextView;
  StatsDump mStatsDump;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_benchmarking);
    mTextView = findViewById(R.id.log_view);

    Intent intent = getIntent();

    // Pick the first .pte file found in the model directory passed via the intent.
    File modelDir = new File(intent.getStringExtra("model_dir"));
    File model =
        Arrays.stream(modelDir.listFiles())
            .filter(file -> file.getName().endsWith(".pte"))
            .findFirst()
            .get();
    String tokenizerPath = intent.getStringExtra("tokenizer_path");

    float temperature = intent.getFloatExtra("temperature", 0.8f);
    mPrompt = intent.getStringExtra("prompt");
    if (mPrompt == null) {
      mPrompt = "The ultimate answer";
    }

    mStatsDump = new StatsDump();
    mStatsDump.modelName = model.getName().replace(".pte", "");
    mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
    mStatsDump.loadStart = System.nanoTime();
  }

  @Override
  public void onModelLoaded(int status) {
    mStatsDump.loadEnd = System.nanoTime();
    mStatsDump.loadStatus = status;
    if (status != 0) {
      Log.e("LlmBenchmarkRunner", "Load failed: " + status);
      onGenerationStopped();
      return;
    }
    mStatsDump.generateStart = System.nanoTime();
    mModelRunner.generate(mPrompt);
  }

  @Override
  public void onTokenGenerated(String token) {
    runOnUiThread(
        () -> {
          mTextView.append(token);
        });
  }

  @Override
  public void onStats(String stats) {
    mStatsDump.tokens = stats;
  }

  @Override
  public void onGenerationStopped() {
    mStatsDump.generateEnd = System.nanoTime();
    runOnUiThread(
        () -> {
          mTextView.append(mStatsDump.toString());
        });

    final BenchmarkMetric.BenchmarkModel benchmarkModel =
        BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName);
    final List<BenchmarkMetric> results = new ArrayList<>();
    // The list of metrics we have at the moment includes:
    // Load status
    results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
    // Model load time
    results.add(
        new BenchmarkMetric(
            benchmarkModel,
            "model_load_time(ms)",
            (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6,
            0.0f));
    // LLM generate time
    results.add(
        new BenchmarkMetric(
            benchmarkModel,
            "generate_time(ms)",
            (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6,
            0.0f));
    // Tokens per second
    results.add(
        new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));

    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
      Gson gson = new Gson();
      writer.write(gson.toJson(results));
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  // Extract the first numeric value from the stats string and treat it as tokens/sec.
  private double extractTPS(final String tokens) {
    final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
    if (m.find()) {
      return Double.parseDouble(m.group());
    } else {
      return 0.0;
    }
  }
}

class BenchmarkMetric {
  public static class BenchmarkModel {
    // The model name, e.g. stories110M
    String name;
    String backend;
    String quantization;

    public BenchmarkModel(final String name, final String backend, final String quantization) {
      this.name = name;
      this.backend = backend;
      this.quantization = quantization;
    }
  }

  BenchmarkModel benchmarkModel;

  // The metric name, e.g. TPS
  String metric;

  // The actual value and the optional target value
  double actualValue;
  double targetValue;

  public static class DeviceInfo {
    // Let's see which information we want to include here
    final String device = Build.BRAND;
    // The phone model and Android release version
    final String arch = Build.MODEL;
    final String os = "Android " + Build.VERSION.RELEASE;
    final long totalMem = new ActivityManager.MemoryInfo().totalMem;
    final long availMem = new ActivityManager.MemoryInfo().availMem;
  }

  DeviceInfo deviceInfo = new DeviceInfo();

  public BenchmarkMetric(
      final BenchmarkModel benchmarkModel,
      final String metric,
      final double actualValue,
      final double targetValue) {
    this.benchmarkModel = benchmarkModel;
    this.metric = metric;
    this.actualValue = actualValue;
    this.targetValue = targetValue;
  }

  // TODO (huydhn): Figure out a way to extract the backend and quantization information from
  // the .pte model itself instead of parsing its name.
  // Expects names of the form <name>_<backend>_<quantization>,
  // e.g. stories110M_xnnpack_fp32 -> (stories110M, xnnpack, fp32).
  public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
    final Matcher m =
        Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
    if (m.matches()) {
      return new BenchmarkMetric.BenchmarkModel(
          m.group("name"), m.group("backend"), m.group("quantization"));
    } else {
      return new BenchmarkMetric.BenchmarkModel(model, "", "");
    }
  }
}

class StatsDump {
  int loadStatus;
  long loadStart;
  long loadEnd;
  long generateStart;
  long generateEnd;
  String tokens;
  String modelName;

  @NonNull
  @Override
  public String toString() {
    return "loadStart: "
        + loadStart
        + "\nloadEnd: "
        + loadEnd
        + "\ngenerateStart: "
        + generateStart
        + "\ngenerateEnd: "
        + generateEnd
        + "\n"
        + tokens;
  }
}
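
// A minimal sketch of how this activity could be launched for a one-off benchmark run,
// assuming the .pte model and tokenizer have already been pushed to the device and the
// activity is exported in the manifest. The /data/local/tmp/llama path is purely
// illustrative; the extra keys match the ones read in onCreate() above, and "prompt"
// falls back to "The ultimate answer" when omitted:
//
//   adb shell am start -n com.example.executorchllamademo/.LlmBenchmarkRunner \
//     --es model_dir /data/local/tmp/llama \
//     --es tokenizer_path /data/local/tmp/llama/tokenizer.bin \
//     --ef temperature 0.8
//
// Results are written to the app's files directory as benchmark_results.json.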