/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama3.2 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.
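//
// Illustrative usage (a minimal sketch, not part of this header; the model
// path, tokenizer path, temperature, prompt, and seq_len below are placeholder
// values):
//
//   example::Runner runner(
//       {"llama3_2_qnn.pte"}, "tokenizer.bin", /*temperature=*/0.8f);
//   runner.load();
//   runner.generate(
//       "What is ExecuTorch?", /*system_prompt=*/"", /*seq_len=*/128,
//       [](const std::string& piece) { std::cout << piece << std::flush; },
//       [](const example::Runner::Stats& stats) { /* inspect timing stats */ });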

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>

namespace example {

class Runner {
 public:
  explicit Runner(
      const std::vector<std::string>& models_path,
      const std::string& tokenizer_path,
      const float temperature);

  struct Stats {
    // Scaling factor for timestamps - in this case, we use ms.
    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
    // Timestamps for the different stages of the execution.
    // model_load_start_ms: Start of model loading.
    long model_load_start_ms;
    // model_load_end_ms: End of model loading.
    long model_load_end_ms;
    // inference_start_ms: Start of inference, recorded immediately after the
    // model is loaded (or after the load check).
    long inference_start_ms;
    // prompt_eval_end_ms: End of prompt array allocation and tokenization,
    // recorded right before the inference loop starts.
    long prompt_eval_end_ms;
    // first_token_ms: Timestamp when the first generated token is emitted.
    long first_token_ms;
    // inference_end_ms: End of inference/generation.
    long inference_end_ms;
    // Running total of the time spent in sampling.
    long aggregate_sampling_time_ms;
    // Number of tokens in the prompt.
    int64_t num_prompt_tokens;
    // Number of generated tokens (total - prompt).
    int64_t num_generated_tokens;
  };
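  // A possible way to derive decode throughput from a populated Stats instance
  // (a sketch assuming the fields are filled in as described above, with
  // timestamps in ms per SCALING_FACTOR_UNITS_PER_SECOND):
  //
  //   double generation_s =
  //       (stats.inference_end_ms - stats.prompt_eval_end_ms) /
  //       static_cast<double>(stats.SCALING_FACTOR_UNITS_PER_SECOND);
  //   double tokens_per_s = stats.num_generated_tokens / generation_s;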

  // Whether the model(s), tokenizer, and sampler have been loaded.
  bool is_loaded() const;
  // Loads the model(s), tokenizer, sampler, and I/O memory.
  executorch::runtime::Error load();
  // Generates up to seq_len tokens for the given prompt. token_callback is
  // invoked for each decoded piece; stats_callback receives the timing stats
  // when generation finishes.
  executorch::runtime::Error generate(
      const std::string& prompt,
      const std::string& system_prompt,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {});
  // Requests generation to stop.
  void stop();
  // Returns the MethodMeta of each loaded module.
  std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
  get_methods_meta();

 private:
  template <typename T>
  int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor);
  void run_model_step(
      std::vector<std::vector<executorch::runtime::EValue>>& inputs);
  // metadata
  int32_t bos_id_;
  std::unordered_set<uint64_t> eos_id_;
  const int32_t n_bos_;
  const int32_t n_eos_;
  const int32_t vocab_size_;
  const int32_t max_seq_len_;
  std::vector<std::shared_ptr<executorch::extension::Module>> modules_;
  std::string tokenizer_path_;
  float temperature_;
  std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
  std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
  Stats stats_;
  std::unique_ptr<Memory> io_mem_;
};

} // namespace example