/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"

namespace tflite {
namespace benchmark {

// Whether a given run belongs to the warmup phase or to the measured
// (regular) phase of a benchmark.
enum RunType {
  WARMUP,
  REGULAR,
};

// Captures the results of a benchmark run: model size, startup latency,
// timing statistics, and memory usage.
class BenchmarkResults {
 public:
  BenchmarkResults() {}
  BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
                   uint64_t input_bytes,
                   tensorflow::Stat<int64_t> warmup_time_us,
                   tensorflow::Stat<int64_t> inference_time_us,
                   const profiling::memory::MemoryUsage& init_mem_usage,
                   const profiling::memory::MemoryUsage& overall_mem_usage)
      : model_size_mb_(model_size_mb),
        startup_latency_us_(startup_latency_us),
        input_bytes_(input_bytes),
        warmup_time_us_(warmup_time_us),
        inference_time_us_(inference_time_us),
        init_mem_usage_(init_mem_usage),
        overall_mem_usage_(overall_mem_usage) {}

  double model_size_mb() const { return model_size_mb_; }
  tensorflow::Stat<int64_t> inference_time_us() const {
    return inference_time_us_;
  }
  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
  int64_t startup_latency_us() const { return startup_latency_us_; }
  uint64_t input_bytes() const { return input_bytes_; }

  // Total input bytes consumed per second over the measured runs:
  // (input_bytes_ * count) bytes over (sum / 1e6) seconds, reported in
  // MB/s with 1 MB = 1024 * 1024 bytes.
  double throughput_MB_per_second() const {
    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
                           inference_time_us_.sum();
    return bytes_per_sec / (1024.0 * 1024.0);
  }

  const profiling::memory::MemoryUsage& init_mem_usage() const {
    return init_mem_usage_;
  }
  const profiling::memory::MemoryUsage& overall_mem_usage() const {
    return overall_mem_usage_;
  }

 private:
  double model_size_mb_ = 0.0;
  int64_t startup_latency_us_ = 0;
  uint64_t input_bytes_ = 0;
  tensorflow::Stat<int64_t> warmup_time_us_;
  tensorflow::Stat<int64_t> inference_time_us_;
  profiling::memory::MemoryUsage init_mem_usage_;
  profiling::memory::MemoryUsage overall_mem_usage_;
};
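
// For example (hypothetical numbers): with input_bytes_ = 602112 (one
// 1x224x224x3 float32 input per inference) and a mean measured inference
// time of 10 ms, throughput_MB_per_second() reports roughly
// 602112 / 0.01 / (1024 * 1024) ~= 57.4 MB/s.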

// Listener interface for receiving callbacks at key points of a benchmark
// run. Override the hooks of interest; the default implementations do
// nothing.
class BenchmarkListener {
 public:
  // Called before the (outer) inference loop begins.
  // Note that this is called *after* the interpreter has been initialized, but
  // *before* any warmup runs have been executed.
  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
  // Called before a single (inner) inference call starts.
  virtual void OnSingleRunStart(RunType runType) {}
  // Called after a single (inner) inference call ends.
  virtual void OnSingleRunEnd() {}
  // Called after the (outer) inference loop ends.
  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
  virtual ~BenchmarkListener() {}
};
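
// A minimal sketch of a custom listener (hypothetical; AverageLatencyListener
// is not part of this header) that records the mean measured latency:
//
//   class AverageLatencyListener : public BenchmarkListener {
//    public:
//     void OnBenchmarkEnd(const BenchmarkResults& results) override {
//       avg_latency_us_ = results.inference_time_us().avg();
//     }
//     double avg_latency_us() const { return avg_latency_us_; }
//
//    private:
//     double avg_latency_us_ = 0.0;
//   };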

// A listener that forwards its method calls to a collection of listeners.
class BenchmarkListeners : public BenchmarkListener {
 public:
  // Adds a listener to the listener collection.
  // |listener| is not owned by the instance of |BenchmarkListeners|.
  // |listener| should not be null and should outlive the instance of
  // |BenchmarkListeners|.
  void AddListener(BenchmarkListener* listener) {
    listeners_.push_back(listener);
  }

  // Removes all listeners at positions >= |index|, including the one at
  // |index|.
  void RemoveListeners(int index) {
    if (index >= NumListeners()) return;
    listeners_.resize(index);
  }

  int NumListeners() const { return listeners_.size(); }

  void OnBenchmarkStart(const BenchmarkParams& params) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkStart(params);
    }
  }

  void OnSingleRunStart(RunType runType) override {
    for (auto listener : listeners_) {
      listener->OnSingleRunStart(runType);
    }
  }

  void OnSingleRunEnd() override {
    for (auto listener : listeners_) {
      listener->OnSingleRunEnd();
    }
  }

  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkEnd(results);
    }
  }

  ~BenchmarkListeners() override {}

 private:
  // Use a vector so listeners are invoked in the order they are added.
  std::vector<BenchmarkListener*> listeners_;
};
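
// Usage sketch (logging_listener and profiling_listener are hypothetical
// instances): listeners fire in insertion order, and RemoveListeners()
// truncates the collection starting at a given index:
//
//   BenchmarkListeners listeners;
//   listeners.AddListener(&logging_listener);
//   listeners.AddListener(&profiling_listener);
//   listeners.RemoveListeners(1);  // Keeps logging_listener only.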

// Benchmark listener that just logs the results of a benchmark run.
class BenchmarkLoggingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override;
};

// Creates an optional command-line flag named |name| whose parsed value is
// written back into |params| under the same key; the value currently stored
// in |params| serves as the flag's default.
template <typename T>
Flag CreateFlag(const char* name, BenchmarkParams* params,
                const std::string& usage) {
  return Flag(
      name, [params, name](const T& val) { params->Set<T>(name, val); },
      params->Get<T>(name), usage, Flag::kOptional);
}
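
// For example (a sketch; assumes |params| already defines an int32_t
// parameter keyed "num_runs"):
//
//   std::vector<Flag> flags = {
//       CreateFlag<int32_t>("num_runs", &params, "number of inference runs"),
//   };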

// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
 public:
  static BenchmarkParams DefaultParams();
  BenchmarkModel();
  explicit BenchmarkModel(BenchmarkParams params)
      : params_(std::move(params)) {}
  virtual ~BenchmarkModel() {}
  virtual TfLiteStatus Init() = 0;
  TfLiteStatus Run(int argc, char** argv);
  virtual TfLiteStatus Run();
  void AddListener(BenchmarkListener* listener) {
    listeners_.AddListener(listener);
  }
  // Removes all listeners at positions >= |index|, including the one at
  // |index|.
  void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
  int NumListeners() const { return listeners_.NumListeners(); }

  BenchmarkParams* mutable_params() { return &params_; }

  // Unparsable flags will remain in 'argv' in the original order and 'argc'
  // will be updated accordingly.
  TfLiteStatus ParseFlags(int* argc, char** argv);

 protected:
  virtual void LogParams();
  virtual TfLiteStatus ValidateParams();

  TfLiteStatus ParseFlags(int argc, char** argv) {
    return ParseFlags(&argc, argv);
  }
  virtual std::vector<Flag> GetFlags();

  // Gets the model file size if it's available; returns -1 otherwise.
  virtual int64_t MayGetModelFileSize() { return -1; }
  virtual uint64_t ComputeInputBytes() = 0;
  virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                        float max_secs, RunType run_type,
                                        TfLiteStatus* invoke_status);
  // Prepares input data for benchmark. This can be used to initialize input
  // data that has non-trivial cost.
  virtual TfLiteStatus PrepareInputData();

  virtual TfLiteStatus ResetInputsAndOutputs();
  virtual TfLiteStatus RunImpl() = 0;
  BenchmarkParams params_;
  BenchmarkListeners listeners_;
};
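
// Typical driver (a hypothetical sketch; MyBenchmark stands in for a concrete
// subclass implementing Init(), ComputeInputBytes(), and RunImpl()):
//
//   int main(int argc, char** argv) {
//     MyBenchmark benchmark;
//     BenchmarkLoggingListener logger;
//     benchmark.AddListener(&logger);
//     return benchmark.Run(argc, argv) == kTfLiteOk ? 0 : 1;
//   }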

}  // namespace benchmark
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_