/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/profiling/memory_usage_monitor.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"

namespace tflite {
namespace benchmark {

enum RunType {
  WARMUP,
  REGULAR,
};

class BenchmarkResults {
 public:
  BenchmarkResults() {}
  BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
                   uint64_t input_bytes,
                   tensorflow::Stat<int64_t> warmup_time_us,
                   tensorflow::Stat<int64_t> inference_time_us,
                   const profiling::memory::MemoryUsage& init_mem_usage,
                   const profiling::memory::MemoryUsage& overall_mem_usage,
                   float peak_mem_mb)
      : model_size_mb_(model_size_mb),
        startup_latency_us_(startup_latency_us),
        input_bytes_(input_bytes),
        warmup_time_us_(warmup_time_us),
        inference_time_us_(inference_time_us),
        init_mem_usage_(init_mem_usage),
        overall_mem_usage_(overall_mem_usage),
        peak_mem_mb_(peak_mem_mb) {}

  double model_size_mb() const { return model_size_mb_; }
  tensorflow::Stat<int64_t> inference_time_us() const {
    return inference_time_us_;
  }
  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
  int64_t startup_latency_us() const { return startup_latency_us_; }
  uint64_t input_bytes() const { return input_bytes_; }
  double throughput_MB_per_second() const {
    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
                           inference_time_us_.sum();
    return bytes_per_sec / (1024.0 * 1024.0);
  }
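  // Worked example with made-up numbers (not from this codebase): if each
  // inference consumes input_bytes_ = 1048576 (1 MiB) and inference_time_us_
  // records 50 runs totalling 5,000,000 us, then bytes_per_sec =
  // 1048576 * 50 * 1e6 / 5e6 = 10,485,760, which the final division reports
  // as 10.0 MB/s (10 runs per second at 1 MiB per run).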

  const profiling::memory::MemoryUsage& init_mem_usage() const {
    return init_mem_usage_;
  }
  const profiling::memory::MemoryUsage& overall_mem_usage() const {
    return overall_mem_usage_;
  }
  float peak_mem_mb() const { return peak_mem_mb_; }

 private:
  double model_size_mb_ = 0.0;
  int64_t startup_latency_us_ = 0;
  uint64_t input_bytes_ = 0;
  tensorflow::Stat<int64_t> warmup_time_us_;
  tensorflow::Stat<int64_t> inference_time_us_;
  profiling::memory::MemoryUsage init_mem_usage_;
  profiling::memory::MemoryUsage overall_mem_usage_;
  // The value may be invalid when memory footprint isn't monitored for the
  // inference, or when memory usage info isn't available on the benchmarking
  // platform.
  float peak_mem_mb_ =
      profiling::memory::MemoryUsageMonitor::kInvalidMemUsageMB;
};

class BenchmarkListener {
 public:
  // Called before the (outer) inference loop begins.
  // Note that this is called *after* the interpreter has been initialized, but
  // *before* any warmup runs have been executed.
  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
  // Called before a single (inner) inference call starts.
  virtual void OnSingleRunStart(RunType runType) {}
  // Called after a single (inner) inference call ends.
  virtual void OnSingleRunEnd() {}
  // Called after the (outer) inference loop ends.
  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
  virtual ~BenchmarkListener() {}
};
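
// Example (an illustrative sketch, not part of this API): a custom listener
// that prints the average inference latency once the benchmark finishes. The
// class name is hypothetical.
//
//   class AvgLatencyListener : public BenchmarkListener {
//    public:
//     void OnBenchmarkEnd(const BenchmarkResults& results) override {
//       printf("avg inference latency: %.1f us\n",
//              results.inference_time_us().avg());
//     }
//   };
//
// An instance would be registered via BenchmarkModel::AddListener() below.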

// A listener that forwards its method calls to a collection of listeners.
class BenchmarkListeners : public BenchmarkListener {
 public:
  // Adds a listener to the listener collection.
  // |listener| is not owned by the instance of |BenchmarkListeners|.
  // |listener| should not be null and should outlast the instance of
  // |BenchmarkListeners|.
  void AddListener(BenchmarkListener* listener) {
    listeners_.push_back(listener);
  }

  // Removes all listeners at and after position |index|.
  void RemoveListeners(int index) {
    if (index >= NumListeners()) return;
    listeners_.resize(index);
  }

  int NumListeners() const { return listeners_.size(); }

  void OnBenchmarkStart(const BenchmarkParams& params) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkStart(params);
    }
  }

  void OnSingleRunStart(RunType runType) override {
    for (auto listener : listeners_) {
      listener->OnSingleRunStart(runType);
    }
  }

  void OnSingleRunEnd() override {
    for (auto listener : listeners_) {
      listener->OnSingleRunEnd();
    }
  }

  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkEnd(results);
    }
  }

  ~BenchmarkListeners() override {}

 private:
  // Use vector so listeners are invoked in the order they are added.
  std::vector<BenchmarkListener*> listeners_;
};

// A benchmark listener that simply logs the results of a benchmark run.
class BenchmarkLoggingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override;
};

// Creates a command-line flag named |name| whose parsed value is stored into
// the identically named parameter in |params|; the parameter's current value
// serves as the flag's default.
template <typename T>
Flag CreateFlag(const char* name, BenchmarkParams* params,
                const std::string& usage) {
  return Flag(
      name,
      [params, name](const T& val, int argv_position) {
        params->Set<T>(name, val, argv_position);
      },
      params->Get<T>(name), usage, Flag::kOptional);
}
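
// Example (an illustrative sketch): binding the "num_runs" parameter of a
// BenchmarkParams instance to a command-line flag. The named parameter must
// already exist in |params|, since CreateFlag reads its current value as the
// flag's default.
//
//   BenchmarkParams params = BenchmarkModel::DefaultParams();
//   std::vector<Flag> flags = {
//       CreateFlag<int32_t>("num_runs", &params, "number of inference runs"),
//   };
//   Flags::Parse(&argc, const_cast<const char**>(argv), flags);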

// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
 public:
  static BenchmarkParams DefaultParams();
  BenchmarkModel();
  explicit BenchmarkModel(BenchmarkParams params)
      : params_(std::move(params)) {}
  virtual ~BenchmarkModel() {}
  virtual TfLiteStatus Init() = 0;
  TfLiteStatus Run(int argc, char** argv);
  virtual TfLiteStatus Run();
  void AddListener(BenchmarkListener* listener) {
    listeners_.AddListener(listener);
  }
  // Removes all listeners at and after position |index|.
  void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
  int NumListeners() const { return listeners_.NumListeners(); }

  BenchmarkParams* mutable_params() { return &params_; }

  // Unparsable flags will remain in 'argv' in the original order, and 'argc'
  // will be updated accordingly.
  TfLiteStatus ParseFlags(int* argc, char** argv);

 protected:
  virtual void LogParams();
  virtual TfLiteStatus ValidateParams();

  TfLiteStatus ParseFlags(int argc, char** argv) {
    return ParseFlags(&argc, argv);
  }
  virtual std::vector<Flag> GetFlags();

  // Gets the model file size if it's available; returns -1 otherwise.
  virtual int64_t MayGetModelFileSize() { return -1; }
  virtual uint64_t ComputeInputBytes() = 0;
  virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                        float max_secs, RunType run_type,
                                        TfLiteStatus* invoke_status);
  // Prepares input data for the benchmark. This can be used to initialize
  // input data that has a non-trivial cost.
  virtual TfLiteStatus PrepareInputData();

  virtual TfLiteStatus ResetInputsAndOutputs();
  virtual TfLiteStatus RunImpl() = 0;

  // Creates a MemoryUsageMonitor to report peak memory footprint if
  // specified.
  virtual std::unique_ptr<profiling::memory::MemoryUsageMonitor>
  MayCreateMemoryUsageMonitor() const;

  BenchmarkParams params_;
  BenchmarkListeners listeners_;
};
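
// Example (an illustrative sketch, not part of this header): the minimal set
// of overrides a concrete benchmark would provide. All names other than the
// BenchmarkModel API are hypothetical.
//
//   class MyModelBenchmark : public BenchmarkModel {
//    public:
//     TfLiteStatus Init() override {
//       // Load the model, build the interpreter, allocate tensors.
//       return kTfLiteOk;
//     }
//
//    protected:
//     uint64_t ComputeInputBytes() override {
//       // Sum of the byte sizes of all input tensors.
//       return total_input_bytes_;
//     }
//     TfLiteStatus RunImpl() override {
//       // Invoke the model once; timing and statistics are handled by
//       // BenchmarkModel::Run.
//       return kTfLiteOk;
//     }
//
//    private:
//     uint64_t total_input_bytes_ = 0;
//   };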

}  // namespace benchmark
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_