/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/profiling/memory_usage_monitor.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"

namespace tflite {
namespace benchmark {

enum RunType {
  WARMUP,
  REGULAR,
};

class BenchmarkResults {
 public:
  BenchmarkResults() {}
  BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
                   uint64_t input_bytes,
                   tensorflow::Stat<int64_t> warmup_time_us,
                   tensorflow::Stat<int64_t> inference_time_us,
                   const profiling::memory::MemoryUsage& init_mem_usage,
                   const profiling::memory::MemoryUsage& overall_mem_usage,
                   float peak_mem_mb)
      : model_size_mb_(model_size_mb),
        startup_latency_us_(startup_latency_us),
        input_bytes_(input_bytes),
        warmup_time_us_(warmup_time_us),
        inference_time_us_(inference_time_us),
        init_mem_usage_(init_mem_usage),
        overall_mem_usage_(overall_mem_usage),
        peak_mem_mb_(peak_mem_mb) {}

  const double model_size_mb() const { return model_size_mb_; }
  tensorflow::Stat<int64_t> inference_time_us() const {
    return inference_time_us_;
  }
  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
  int64_t startup_latency_us() const { return startup_latency_us_; }
  uint64_t input_bytes() const { return input_bytes_; }
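  // Average input throughput across the timed (non-warmup) inference runs:
  // count() / sum() gives runs per microsecond, scaled by 1e6 to runs per
  // second, multiplied by the per-run input size, and finally divided by
  // 1024 * 1024 to report MB per second.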
  double throughput_MB_per_second() const {
    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
                           inference_time_us_.sum();
    return bytes_per_sec / (1024.0 * 1024.0);
  }

  const profiling::memory::MemoryUsage& init_mem_usage() const {
    return init_mem_usage_;
  }
  const profiling::memory::MemoryUsage& overall_mem_usage() const {
    return overall_mem_usage_;
  }
  float peak_mem_mb() const { return peak_mem_mb_; }

 private:
  double model_size_mb_ = 0.0;
  int64_t startup_latency_us_ = 0;
  uint64_t input_bytes_ = 0;
  tensorflow::Stat<int64_t> warmup_time_us_;
  tensorflow::Stat<int64_t> inference_time_us_;
  profiling::memory::MemoryUsage init_mem_usage_;
  profiling::memory::MemoryUsage overall_mem_usage_;
  // Holds an invalid value when memory footprint isn't monitored for the
  // inference, or when memory usage info isn't available on the benchmarking
  // platform.
  float peak_mem_mb_ =
      profiling::memory::MemoryUsageMonitor::kInvalidMemUsageMB;
};

class BenchmarkListener {
 public:
  // Called before the (outer) inference loop begins.
  // Note that this is called *after* the interpreter has been initialized, but
  // *before* any warmup runs have been executed.
  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
  // Called before a single (inner) inference call starts.
  virtual void OnSingleRunStart(RunType runType) {}
  // Called after a single (inner) inference call ends.
  virtual void OnSingleRunEnd() {}
  // Called after the (outer) inference loop ends.
  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
  virtual ~BenchmarkListener() {}
};
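
// A minimal sketch of a custom listener, assuming only the hooks declared
// above; the class name is hypothetical, and the example relies on
// Stat<int64_t>::avg() from stats_calculator.h plus an <iostream> include:
//
//   class AverageLatencyListener : public BenchmarkListener {
//    public:
//     void OnBenchmarkEnd(const BenchmarkResults& results) override {
//       // Mean latency over the timed (non-warmup) inference runs.
//       std::cout << "avg inference latency (us): "
//                 << results.inference_time_us().avg() << "\n";
//     }
//   };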

// A listener that forwards its method calls to a collection of listeners.
class BenchmarkListeners : public BenchmarkListener {
 public:
  // Adds a listener to the listener collection.
  // |listener| is not owned by the instance of |BenchmarkListeners|.
  // |listener| should not be null and should outlast the instance of
  // |BenchmarkListeners|.
  void AddListener(BenchmarkListener* listener) {
    listeners_.push_back(listener);
  }

  // Removes the listener at position 'index' and all listeners after it.
  void RemoveListeners(int index) {
    if (index >= NumListeners()) return;
    listeners_.resize(index);
  }

  int NumListeners() const { return listeners_.size(); }

  void OnBenchmarkStart(const BenchmarkParams& params) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkStart(params);
    }
  }

  void OnSingleRunStart(RunType runType) override {
    for (auto listener : listeners_) {
      listener->OnSingleRunStart(runType);
    }
  }

  void OnSingleRunEnd() override {
    for (auto listener : listeners_) {
      listener->OnSingleRunEnd();
    }
  }

  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkEnd(results);
    }
  }

  ~BenchmarkListeners() override {}

 private:
  // Use a vector so listeners are invoked in the order they are added.
  std::vector<BenchmarkListener*> listeners_;
};
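
// Usage sketch, reusing the illustrative AverageLatencyListener above and the
// BenchmarkLoggingListener declared below: listeners are held by raw pointer,
// so they must outlive this collection, and callbacks fan out in insertion
// order.
//
//   BenchmarkLoggingListener log_listener;
//   AverageLatencyListener latency_listener;
//   BenchmarkListeners listeners;
//   listeners.AddListener(&log_listener);
//   listeners.AddListener(&latency_listener);
//   listeners.OnSingleRunEnd();  // log_listener first, then latency_listener.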

// Benchmark listener that just logs the results of the benchmark run.
class BenchmarkLoggingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override;
};

template <typename T>
Flag CreateFlag(const char* name, BenchmarkParams* params,
                const std::string& usage) {
  return Flag(
      name,
      [params, name](const T& val, int argv_position) {
        params->Set<T>(name, val, argv_position);
      },
      params->Get<T>(name), usage, Flag::kOptional);
}
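
// Illustrative only: wiring a command-line flag to a benchmark parameter. The
// "num_runs" parameter name and the int32_t type are assumptions about what
// DefaultParams() provides; the flag's default value shown to users is read
// from params.Get<T>(name).
//
//   BenchmarkParams params = BenchmarkModel::DefaultParams();
//   std::vector<Flag> flags = {
//       CreateFlag<int32_t>("num_runs", &params, "number of inference runs")};
//   Flags::Parse(&argc, const_cast<const char**>(argv), flags);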

// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
 public:
  static BenchmarkParams DefaultParams();
  BenchmarkModel();
  explicit BenchmarkModel(BenchmarkParams params)
      : params_(std::move(params)) {}
  virtual ~BenchmarkModel() {}
  virtual TfLiteStatus Init() = 0;
  TfLiteStatus Run(int argc, char** argv);
  virtual TfLiteStatus Run();
  void AddListener(BenchmarkListener* listener) {
    listeners_.AddListener(listener);
  }
  // Removes the listener at position 'index' and all listeners after it.
  void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
  int NumListeners() const { return listeners_.NumListeners(); }

  BenchmarkParams* mutable_params() { return &params_; }

  // Unparsable flags will remain in 'argv' in the original order and 'argc'
  // will be updated accordingly.
  TfLiteStatus ParseFlags(int* argc, char** argv);

 protected:
  virtual void LogParams();
  virtual TfLiteStatus ValidateParams();

  TfLiteStatus ParseFlags(int argc, char** argv) {
    return ParseFlags(&argc, argv);
  }
  virtual std::vector<Flag> GetFlags();

  // Gets the model file size if it's available.
  virtual int64_t MayGetModelFileSize() { return -1; }
  virtual uint64_t ComputeInputBytes() = 0;
  virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                        float max_secs, RunType run_type,
                                        TfLiteStatus* invoke_status);
  // Prepares input data for the benchmark. This can be used to initialize
  // input data that has non-trivial cost.
  virtual TfLiteStatus PrepareInputData();

  virtual TfLiteStatus ResetInputsAndOutputs();
  virtual TfLiteStatus RunImpl() = 0;

  // Creates a MemoryUsageMonitor to report the peak memory footprint if
  // specified.
  virtual std::unique_ptr<profiling::memory::MemoryUsageMonitor>
  MayCreateMemoryUsageMonitor() const;

  BenchmarkParams params_;
  BenchmarkListeners listeners_;
};
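
// A rough sketch of the subclassing contract described above; the class name
// and method bodies are hypothetical placeholders, not an existing
// implementation. Init(), RunImpl(), and ComputeInputBytes() are the pure
// virtuals a concrete benchmark must provide.
//
//   class MyBenchmark : public BenchmarkModel {
//    public:
//     TfLiteStatus Init() override { /* build the runtime */ return kTfLiteOk; }
//     TfLiteStatus RunImpl() override { /* one inference */ return kTfLiteOk; }
//     uint64_t ComputeInputBytes() override { return 0; }
//   };
//
//   // Typical driver: attach a logging listener, then parse flags and run.
//   MyBenchmark benchmark;
//   BenchmarkLoggingListener logger;
//   benchmark.AddListener(&logger);
//   benchmark.Run(argc, argv);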

}  // namespace benchmark
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_