/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A Llama 3.2 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.
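//
// Example usage (illustrative sketch only; the file names below are
// assumptions and the real wiring lives in the host application):
//
//   example::Runner runner(
//       {"llama3_2_shard.pte"},    // hypothetical model shard path(s)
//       "tokenizer.model",         // hypothetical tiktoken tokenizer path
//       /*temperature=*/0.8f);
//   // generate() lazily calls load() on first use.
//   runner.generate(
//       "What is ExecuTorch?",     // user prompt
//       "",                        // optional system prompt
//       128,                       // requested sequence length
//       [](const std::string& piece) { std::cout << piece << std::flush; },
//       [](const example::Runner::Stats& stats) { /* consume timing stats */ });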

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/log.h>
#include <chrono>
#include <cinttypes>
#include <ctime>
#include <iostream>
#include <memory>
#include <sstream>

using executorch::aten::Tensor;
using executorch::extension::Module;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;

namespace example {

namespace {
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

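// The runner owns one Module per model shard. The constructor creates the
// modules (memory-mapped), loads the tiktoken tokenizer, registers the
// Llama 3 stop tokens, and allocates the shared KV-cache IO memory.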
Runner::Runner(
    const std::vector<std::string>& models_path,
    const std::string& tokenizer_path,
    const float temperature)
    : n_bos_(1),
      n_eos_(1),
      vocab_size_(QNN_LLAMA3_2_LOGITS),
      max_seq_len_(QNN_LLAMA3_2_SEQLEN),
      tokenizer_path_(tokenizer_path),
      temperature_(temperature),
      stats_({}) {
  for (size_t i = 0; i < models_path.size(); ++i) {
    modules_.push_back(std::make_shared<Module>(
        models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors));
    ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str());
  }
  ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str());

  tokenizer_ = example::get_tiktoken_for_llama();
  Error err = tokenizer_->load(tokenizer_path_);
  ET_CHECK_MSG(
      err == Error::Ok, "failed to load tokenizer %s", tokenizer_path_.c_str());
  eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
  bos_id_ = tokenizer_->bos_tok();
  eos_id_.insert(tokenizer_->eos_tok());
  io_mem_ = std::make_unique<KVCachedMemory>(modules_);
  ET_LOG(Info, "creating io_memory");
}

bool Runner::is_loaded() const {
  bool loaded = true;
  for (const std::shared_ptr<Module>& module : modules_) {
    loaded &= module->is_loaded();
  }
  return loaded && tokenizer_ && sampler_;
}

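// load() is idempotent: it loads the "forward" method of every shard, builds
// the top-p sampler, and passes each shard's method metadata to the IO memory
// so it can prepare the input/output buffers.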
Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  for (std::shared_ptr<Module>& module : modules_) {
    ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward"));
  }

  // create sampler
  sampler_ = std::make_unique<Sampler>(
      vocab_size_,
      temperature_,
      kTopp,
      static_cast<unsigned long long>(std::time(nullptr)));

  // prepare io
  auto methods_meta = get_methods_meta();
  io_mem_->prepare_io(methods_meta);
  return Error::Ok;
}

template <typename T>
int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
  T* logits = logits_tensor.mutable_data_ptr<T>();

  // Since the logits are for all tokens, get the last token probabilities
  T* logits_last = logits;
  return sampler_->sample(logits_last);
}

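// Runs one inference step by invoking "forward" on every shard in order.
// Output tensors are pre-bound via Module::set_output in generate(), so each
// shard writes its results directly into the shared IO memory.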
void Runner::run_model_step(std::vector<std::vector<EValue>>& inputs) {
  for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) {
    Result<std::vector<EValue>> outputs_res = modules_[i]->forward(inputs[i]);
    ET_CHECK_MSG(
        outputs_res.error() == Error::Ok, "shard %zu inference failed", i);
  }
}

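// Generates text for the given prompt. The flow is: lazily load the model and
// bind the IO tensors, wrap the prompt in the Llama 3 chat template, tokenize
// it, then run the prefill/decode loop until an EOS token or the sequence
// length limit is reached, streaming each decoded piece through
// token_callback.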
Error Runner::generate(
    const std::string& prompt,
    const std::string& system_prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "prompt cannot be empty");

  std::vector<std::vector<Tensor>> input_tensors, output_tensors;
  std::vector<std::vector<EValue>> inputs;
  if (!is_loaded()) {
    stats_.model_load_start_ms = time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    for (int i = 0; i < modules_.size(); ++i) {
      input_tensors.emplace_back(io_mem_->get_input_tensors(i));
      output_tensors.emplace_back(io_mem_->get_output_tensors(i));
      for (size_t j = 0; j < output_tensors[i].size(); ++j) {
        ET_CHECK_MSG(
            modules_[i]->set_output(output_tensors[i][j], j) == Error::Ok,
            "failed to set output tensor for module %d's %zu'th output",
            i,
            j);
      }
      inputs.emplace_back(
          std::vector<EValue>(begin(input_tensors[i]), end(input_tensors[i])));
    }
    stats_.model_load_end_ms = time_in_ms();
  }

  stats_.inference_start_ms = time_in_ms();
  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

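  // Wrap the raw prompt in the Llama 3 chat template: an optional system
  // block, the user turn, then the assistant header for the model to complete.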
  std::string post_process_prompt;

  if (!system_prompt.empty()) {
    post_process_prompt.append(
        "<|start_header_id|>system<|end_header_id|>\n\n");
    post_process_prompt.append(system_prompt);
    post_process_prompt.append("<|eot_id|>\n");
  }
  post_process_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n");
  post_process_prompt.append(prompt);
  post_process_prompt.append(
      "<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
  // tokenizer_->encode will add the <|begin_of_text|> token for us.
  // For now, do the token callback so the output format looks the same as the
  // llama3 model card.
  token_callback("<|begin_of_text|>");

  Result<std::vector<uint64_t>> encode_res =
      tokenizer_->encode(post_process_prompt, n_bos_, 0);
  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(),
      "failed to encode prompt %s",
      post_process_prompt.c_str());

  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();
  ET_CHECK_MSG(num_prompt_tokens < max_seq_len_, "max seq length exceeded");
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "sequence length exceeded - please increase the seq_len value");

  int64_t pos = 0, prev_token, cur_token = prompt_tokens[0];
  KVCachedMemory::IO* ptr =
      static_cast<KVCachedMemory::IO*>(io_mem_->get_mutable_ptr());
  ptr->input_tok = static_cast<int32_t>(cur_token);
  ptr->attention_mask[max_seq_len_ - 1] = 0;

  std::vector<long long> postTime;
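  // Prefill + decode loop. While pos is still inside the prompt, the sampled
  // token is discarded and the next prompt token is force-fed; afterwards the
  // sampled token is appended autoregressively until an EOS token is produced
  // or the sequence length limit is hit.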
  while (pos < seq_len - 1) {
    // inference
    run_model_step(inputs);
    Tensor& logits_tensor = output_tensors.back()[0];

    if (pos == num_prompt_tokens) {
      stats_.first_token_ms = time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      stats_.prompt_eval_end_ms = time_in_ms();
    }

    long sample_start_time_ms = time_in_ms();
    prev_token = cur_token;
    cur_token = logitsToToken<float>(logits_tensor);
    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

    if (pos < num_prompt_tokens - 1) {
      cur_token = prompt_tokens[pos + 1];
    }
    io_mem_->update_io(cur_token, ++pos, output_tensors);
    auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());

    if (token_callback) {
      token_callback(piece_res.get().c_str());
    }

    if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
      ET_LOG(Info, "\nReached the end of generation");
      break;
    }
  }
  stats_.inference_end_ms = time_in_ms();
  // pos can reach at most seq_len - 1 because of the loop bound above.
  if (pos >= seq_len - 1) {
    ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = pos - num_prompt_tokens;
  printReport(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

namespace {
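// Logs a human-readable breakdown of the collected timing stats (model load,
// prompt evaluation, token generation, sampling) and emits the same data as a
// single JSON line prefixed with "PyTorchObserver".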
void printReport(const Runner::Stats& stats) {
  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 "    Generated Tokens: %" PRIu64,
      stats.num_prompt_tokens,
      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));
  double inference_time_ms =
      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_generated_tokens) /
          (double)(stats.inference_end_ms - stats.inference_start_ms) *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
  double prompt_eval_time =
      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_prompt_tokens) / prompt_eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  double eval_time =
      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      stats.num_generated_tokens,
      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      stats.num_generated_tokens / eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
      ((double)(stats.first_token_ms - stats.inference_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
      stats.num_prompt_tokens + stats.num_generated_tokens,
      (double)stats.aggregate_sampling_time_ms /
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

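// Serializes the stats into the single-line JSON payload that printReport
// prints behind the "PyTorchObserver" prefix.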
std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
     << "\"first_token_ms\":" << stats.first_token_ms << ","
     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
     << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
} // namespace

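// Collects the "forward" method metadata of every shard; load() passes this
// to io_mem_->prepare_io() when setting up the IO tensors.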
std::vector<Result<MethodMeta>> Runner::get_methods_meta() {
  std::vector<Result<MethodMeta>> methods_meta;
  methods_meta.reserve(modules_.size());
  for (std::shared_ptr<Module>& module : modules_) {
    methods_meta.emplace_back(module->method_meta("forward"));
  }
  return methods_meta;
}
} // namespace example