/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"

namespace tflite {
namespace benchmark {

enum RunType {
  WARMUP,
  REGULAR,
};

class BenchmarkResults {
 public:
  BenchmarkResults() {}
  BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
                   uint64_t input_bytes,
                   tensorflow::Stat<int64_t> warmup_time_us,
                   tensorflow::Stat<int64_t> inference_time_us,
                   const profiling::memory::MemoryUsage& init_mem_usage,
                   const profiling::memory::MemoryUsage& overall_mem_usage)
      : model_size_mb_(model_size_mb),
        startup_latency_us_(startup_latency_us),
        input_bytes_(input_bytes),
        warmup_time_us_(warmup_time_us),
        inference_time_us_(inference_time_us),
        init_mem_usage_(init_mem_usage),
        overall_mem_usage_(overall_mem_usage) {}

  const double model_size_mb() const { return model_size_mb_; }
  tensorflow::Stat<int64_t> inference_time_us() const {
    return inference_time_us_;
  }
  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
  int64_t startup_latency_us() const { return startup_latency_us_; }
  uint64_t input_bytes() const { return input_bytes_; }
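  // Average input throughput in MB/s: total bytes fed to the model across all
  // inference runs, divided by the total inference time.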
  double throughput_MB_per_second() const {
    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
                           inference_time_us_.sum();
    return bytes_per_sec / (1024.0 * 1024.0);
  }

  const profiling::memory::MemoryUsage& init_mem_usage() const {
    return init_mem_usage_;
  }
  const profiling::memory::MemoryUsage& overall_mem_usage() const {
    return overall_mem_usage_;
  }

 private:
  double model_size_mb_ = 0.0;
  int64_t startup_latency_us_ = 0;
  uint64_t input_bytes_ = 0;
  tensorflow::Stat<int64_t> warmup_time_us_;
  tensorflow::Stat<int64_t> inference_time_us_;
  profiling::memory::MemoryUsage init_mem_usage_;
  profiling::memory::MemoryUsage overall_mem_usage_;
};

class BenchmarkListener {
 public:
  // Called before the (outer) inference loop begins.
  // Note that this is called *after* the interpreter has been initialized, but
  // *before* any warmup runs have been executed.
  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
  // Called before a single (inner) inference call starts.
  virtual void OnSingleRunStart(RunType runType) {}
  // Called after a single (inner) inference call ends.
  virtual void OnSingleRunEnd() {}
  // Called after the (outer) inference loop ends.
  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
  virtual ~BenchmarkListener() {}
};

// A listener that forwards its method calls to a collection of listeners.
class BenchmarkListeners : public BenchmarkListener {
 public:
  // Adds a listener to the listener collection.
  // |listener| is not owned by the instance of |BenchmarkListeners|.
  // |listener| should not be null and should outlast the instance of
  // |BenchmarkListeners|.
  void AddListener(BenchmarkListener* listener) {
    listeners_.push_back(listener);
  }

  // Removes all listeners after [index], including the one at [index].
  void RemoveListeners(int index) {
    if (index >= NumListeners()) return;
    listeners_.resize(index);
  }

  int NumListeners() const { return listeners_.size(); }

  void OnBenchmarkStart(const BenchmarkParams& params) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkStart(params);
    }
  }

  void OnSingleRunStart(RunType runType) override {
    for (auto listener : listeners_) {
      listener->OnSingleRunStart(runType);
    }
  }

  void OnSingleRunEnd() override {
    for (auto listener : listeners_) {
      listener->OnSingleRunEnd();
    }
  }

  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkEnd(results);
    }
  }

  ~BenchmarkListeners() override {}

 private:
  // Use vector so listeners are invoked in the order they are added.
  std::vector<BenchmarkListener*> listeners_;
};

// Benchmark listener that just logs the results of a benchmark run.
class BenchmarkLoggingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override;
};

template <typename T>
Flag CreateFlag(const char* name, BenchmarkParams* params,
                const std::string& usage) {
  return Flag(
      name, [params, name](const T& val) { params->Set<T>(name, val); },
      params->Get<T>(name), usage, Flag::kOptional);
}
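
// A usage sketch for CreateFlag (illustrative only; "num_runs" is an example
// parameter name, not guaranteed by this header):
//
//   BenchmarkParams params = BenchmarkModel::DefaultParams();
//   std::vector<Flag> flags = {
//       CreateFlag<int32_t>("num_runs", &params, "number of inference runs")};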

// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
 public:
  static BenchmarkParams DefaultParams();
  BenchmarkModel();
  explicit BenchmarkModel(BenchmarkParams params)
      : params_(std::move(params)) {}
  virtual ~BenchmarkModel() {}
  virtual TfLiteStatus Init() = 0;
  TfLiteStatus Run(int argc, char** argv);
  virtual TfLiteStatus Run();
  void AddListener(BenchmarkListener* listener) {
    listeners_.AddListener(listener);
  }
  // Removes all listeners after [index], including the one at [index].
  void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
  int NumListeners() const { return listeners_.NumListeners(); }

  BenchmarkParams* mutable_params() { return &params_; }

  // Unparsable flags will remain in 'argv' in the original order and 'argc'
  // will be updated accordingly.
  TfLiteStatus ParseFlags(int* argc, char** argv);

 protected:
  virtual void LogParams();
  virtual TfLiteStatus ValidateParams();

  TfLiteStatus ParseFlags(int argc, char** argv) {
    return ParseFlags(&argc, argv);
  }
  virtual std::vector<Flag> GetFlags();

  // Gets the model file size if it's available.
  virtual int64_t MayGetModelFileSize() { return -1; }
  virtual uint64_t ComputeInputBytes() = 0;
  virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                        float max_secs, RunType run_type,
                                        TfLiteStatus* invoke_status);
  // Prepares input data for the benchmark. This can be used to initialize
  // input data that has a non-trivial cost.
  virtual TfLiteStatus PrepareInputData();

  virtual TfLiteStatus ResetInputsAndOutputs();
  virtual TfLiteStatus RunImpl() = 0;
  BenchmarkParams params_;
  BenchmarkListeners listeners_;
};
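
// A minimal subclassing sketch (illustrative only; "MyModelBenchmark" and its
// members are hypothetical):
//
//   class MyModelBenchmark : public BenchmarkModel {
//    public:
//     // Override the pure-virtual hooks: set up the model, report input size,
//     // and run one inference per RunImpl() call.
//     TfLiteStatus Init() override { return kTfLiteOk; }
//     uint64_t ComputeInputBytes() override { return input_bytes_; }
//     TfLiteStatus RunImpl() override { return kTfLiteOk; }
//
//    private:
//     uint64_t input_bytes_ = 0;
//   };
//
//   // Typical driving code: parse command-line flags, attach a logging
//   // listener, then run the benchmark.
//   MyModelBenchmark benchmark;
//   BenchmarkLoggingListener logger;
//   benchmark.AddListener(&logger);
//   benchmark.Run(argc, argv);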

}  // namespace benchmark
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_