1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h"
17
18 #include "tensorflow/core/util/stats_calculator.h"
19 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
20
21 extern "C" {
22
23 // -----------------------------------------------------------------------------
24 // C APIs corresponding to tflite::benchmark::BenchmarkResults type.
25 // -----------------------------------------------------------------------------
// Opaque C handle wrapping a C++ BenchmarkResults. The pointer is non-owning:
// nothing in this handle frees `results`, so the handle is only meaningful
// while the referenced BenchmarkResults object is alive.
struct TfLiteBenchmarkResults {
  // Borrowed pointer to the underlying C++ results; never deleted here.
  const tflite::benchmark::BenchmarkResults* results;
};
29
30 // Converts the given int64_t stat into a TfLiteBenchmarkInt64Stat struct.
ConvertStat(const tensorflow::Stat<int64_t> & stat)31 TfLiteBenchmarkInt64Stat ConvertStat(const tensorflow::Stat<int64_t>& stat) {
32 return {
33 stat.empty(), stat.first(), stat.newest(), stat.max(),
34 stat.min(), stat.count(), stat.sum(), stat.squared_sum(),
35 stat.all_same(), stat.avg(), stat.std_deviation(),
36 };
37 }
38
TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(const TfLiteBenchmarkResults * results)39 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(
40 const TfLiteBenchmarkResults* results) {
41 return ConvertStat(results->results->inference_time_us());
42 }
43
TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(const TfLiteBenchmarkResults * results)44 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(
45 const TfLiteBenchmarkResults* results) {
46 return ConvertStat(results->results->warmup_time_us());
47 }
48
TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(const TfLiteBenchmarkResults * results)49 int64_t TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(
50 const TfLiteBenchmarkResults* results) {
51 return results->results->startup_latency_us();
52 }
53
TfLiteBenchmarkResultsGetInputBytes(const TfLiteBenchmarkResults * results)54 uint64_t TfLiteBenchmarkResultsGetInputBytes(
55 const TfLiteBenchmarkResults* results) {
56 return results->results->input_bytes();
57 }
58
TfLiteBenchmarkResultsGetThroughputMbPerSecond(const TfLiteBenchmarkResults * results)59 double TfLiteBenchmarkResultsGetThroughputMbPerSecond(
60 const TfLiteBenchmarkResults* results) {
61 return results->results->throughput_MB_per_second();
62 }
63
64 // -----------------------------------------------------------------------------
65 // C APIs corresponding to tflite::benchmark::BenchmarkListener type.
66 // -----------------------------------------------------------------------------
67 class BenchmarkListenerAdapter : public tflite::benchmark::BenchmarkListener {
68 public:
OnBenchmarkStart(const tflite::benchmark::BenchmarkParams & params)69 void OnBenchmarkStart(
70 const tflite::benchmark::BenchmarkParams& params) override {
71 if (on_benchmark_start_fn_ != nullptr) {
72 on_benchmark_start_fn_(user_data_);
73 }
74 }
75
OnSingleRunStart(tflite::benchmark::RunType runType)76 void OnSingleRunStart(tflite::benchmark::RunType runType) override {
77 if (on_single_run_start_fn_ != nullptr) {
78 on_single_run_start_fn_(user_data_, runType == tflite::benchmark::WARMUP
79 ? TfLiteBenchmarkWarmup
80 : TfLiteBenchmarkRegular);
81 }
82 }
83
OnSingleRunEnd()84 void OnSingleRunEnd() override {
85 if (on_single_run_end_fn_ != nullptr) {
86 on_single_run_end_fn_(user_data_);
87 }
88 }
89
OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults & results)90 void OnBenchmarkEnd(
91 const tflite::benchmark::BenchmarkResults& results) override {
92 if (on_benchmark_end_fn_ != nullptr) {
93 TfLiteBenchmarkResults* wrapper = new TfLiteBenchmarkResults{&results};
94 on_benchmark_end_fn_(user_data_, wrapper);
95 delete wrapper;
96 }
97 }
98
99 // Keep the user_data pointer provided when setting the callbacks.
100 void* user_data_;
101
102 // Function pointers set by the TfLiteBenchmarkListenerSetCallbacks call.
103 // Only non-null callbacks will be actually called.
104 void (*on_benchmark_start_fn_)(void* user_data);
105 void (*on_single_run_start_fn_)(void* user_data,
106 TfLiteBenchmarkRunType runType);
107 void (*on_single_run_end_fn_)(void* user_data);
108 void (*on_benchmark_end_fn_)(void* user_data,
109 TfLiteBenchmarkResults* results);
110 };
111
// Opaque C handle for a benchmark listener. It owns the adapter that receives
// the C++ listener events and forwards them to the registered C callbacks;
// deleting the handle destroys the adapter.
struct TfLiteBenchmarkListener {
  std::unique_ptr<BenchmarkListenerAdapter> adapter;
};
115
TfLiteBenchmarkListenerCreate()116 TfLiteBenchmarkListener* TfLiteBenchmarkListenerCreate() {
117 std::unique_ptr<BenchmarkListenerAdapter> adapter(
118 new BenchmarkListenerAdapter());
119 return new TfLiteBenchmarkListener{std::move(adapter)};
120 }
121
TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener * listener)122 void TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener* listener) {
123 delete listener;
124 }
125
TfLiteBenchmarkListenerSetCallbacks(TfLiteBenchmarkListener * listener,void * user_data,void (* on_benchmark_start_fn)(void * user_data),void (* on_single_run_start_fn)(void * user_data,TfLiteBenchmarkRunType runType),void (* on_single_run_end_fn)(void * user_data),void (* on_benchmark_end_fn)(void * user_data,TfLiteBenchmarkResults * results))126 void TfLiteBenchmarkListenerSetCallbacks(
127 TfLiteBenchmarkListener* listener, void* user_data,
128 void (*on_benchmark_start_fn)(void* user_data),
129 void (*on_single_run_start_fn)(void* user_data,
130 TfLiteBenchmarkRunType runType),
131 void (*on_single_run_end_fn)(void* user_data),
132 void (*on_benchmark_end_fn)(void* user_data,
133 TfLiteBenchmarkResults* results)) {
134 listener->adapter->user_data_ = user_data;
135 listener->adapter->on_benchmark_start_fn_ = on_benchmark_start_fn;
136 listener->adapter->on_single_run_start_fn_ = on_single_run_start_fn;
137 listener->adapter->on_single_run_end_fn_ = on_single_run_end_fn;
138 listener->adapter->on_benchmark_end_fn_ = on_benchmark_end_fn;
139 }
140
141 // -----------------------------------------------------------------------------
142 // C APIs corresponding to tflite::benchmark::BenchmarkTfLiteModel type.
143 // -----------------------------------------------------------------------------
// Opaque C handle that owns a C++ BenchmarkTfLiteModel; deleting the handle
// destroys the model.
struct TfLiteBenchmarkTfLiteModel {
  std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model;
};
147
TfLiteBenchmarkTfLiteModelCreate()148 TfLiteBenchmarkTfLiteModel* TfLiteBenchmarkTfLiteModelCreate() {
149 std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model(
150 new tflite::benchmark::BenchmarkTfLiteModel());
151 return new TfLiteBenchmarkTfLiteModel{std::move(benchmark_model)};
152 }
153
TfLiteBenchmarkTfLiteModelDelete(TfLiteBenchmarkTfLiteModel * benchmark_model)154 void TfLiteBenchmarkTfLiteModelDelete(
155 TfLiteBenchmarkTfLiteModel* benchmark_model) {
156 delete benchmark_model;
157 }
158
TfLiteBenchmarkTfLiteModelInit(TfLiteBenchmarkTfLiteModel * benchmark_model)159 TfLiteStatus TfLiteBenchmarkTfLiteModelInit(
160 TfLiteBenchmarkTfLiteModel* benchmark_model) {
161 return benchmark_model->benchmark_model->Init();
162 }
163
TfLiteBenchmarkTfLiteModelRun(TfLiteBenchmarkTfLiteModel * benchmark_model)164 TfLiteStatus TfLiteBenchmarkTfLiteModelRun(
165 TfLiteBenchmarkTfLiteModel* benchmark_model) {
166 return benchmark_model->benchmark_model->Run();
167 }
168
TfLiteBenchmarkTfLiteModelRunWithArgs(TfLiteBenchmarkTfLiteModel * benchmark_model,int argc,char ** argv)169 TfLiteStatus TfLiteBenchmarkTfLiteModelRunWithArgs(
170 TfLiteBenchmarkTfLiteModel* benchmark_model, int argc, char** argv) {
171 return benchmark_model->benchmark_model->Run(argc, argv);
172 }
173
TfLiteBenchmarkTfLiteModelAddListener(TfLiteBenchmarkTfLiteModel * benchmark_model,const TfLiteBenchmarkListener * listener)174 void TfLiteBenchmarkTfLiteModelAddListener(
175 TfLiteBenchmarkTfLiteModel* benchmark_model,
176 const TfLiteBenchmarkListener* listener) {
177 return benchmark_model->benchmark_model->AddListener(listener->adapter.get());
178 }
179
180 } // extern "C"
181