• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h"
17 
18 #include "tensorflow/core/util/stats_calculator.h"
19 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
20 
21 extern "C" {
22 
23 // -----------------------------------------------------------------------------
24 // C APIs corresponding to tflite::benchmark::BenchmarkResults type.
25 // -----------------------------------------------------------------------------
26 struct TfLiteBenchmarkResults {
27   const tflite::benchmark::BenchmarkResults* results;
28 };
29 
30 // Converts the given int64_t stat into a TfLiteBenchmarkInt64Stat struct.
ConvertStat(const tensorflow::Stat<int64_t> & stat)31 TfLiteBenchmarkInt64Stat ConvertStat(const tensorflow::Stat<int64_t>& stat) {
32   return {
33       stat.empty(),    stat.first(), stat.newest(),        stat.max(),
34       stat.min(),      stat.count(), stat.sum(),           stat.squared_sum(),
35       stat.all_same(), stat.avg(),   stat.std_deviation(),
36   };
37 }
38 
TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(const TfLiteBenchmarkResults * results)39 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(
40     const TfLiteBenchmarkResults* results) {
41   return ConvertStat(results->results->inference_time_us());
42 }
43 
TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(const TfLiteBenchmarkResults * results)44 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(
45     const TfLiteBenchmarkResults* results) {
46   return ConvertStat(results->results->warmup_time_us());
47 }
48 
TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(const TfLiteBenchmarkResults * results)49 int64_t TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(
50     const TfLiteBenchmarkResults* results) {
51   return results->results->startup_latency_us();
52 }
53 
TfLiteBenchmarkResultsGetInputBytes(const TfLiteBenchmarkResults * results)54 uint64_t TfLiteBenchmarkResultsGetInputBytes(
55     const TfLiteBenchmarkResults* results) {
56   return results->results->input_bytes();
57 }
58 
TfLiteBenchmarkResultsGetThroughputMbPerSecond(const TfLiteBenchmarkResults * results)59 double TfLiteBenchmarkResultsGetThroughputMbPerSecond(
60     const TfLiteBenchmarkResults* results) {
61   return results->results->throughput_MB_per_second();
62 }
63 
64 // -----------------------------------------------------------------------------
65 // C APIs corresponding to tflite::benchmark::BenchmarkListener type.
66 // -----------------------------------------------------------------------------
67 class BenchmarkListenerAdapter : public tflite::benchmark::BenchmarkListener {
68  public:
OnBenchmarkStart(const tflite::benchmark::BenchmarkParams & params)69   void OnBenchmarkStart(
70       const tflite::benchmark::BenchmarkParams& params) override {
71     if (on_benchmark_start_fn_ != nullptr) {
72       on_benchmark_start_fn_(user_data_);
73     }
74   }
75 
OnSingleRunStart(tflite::benchmark::RunType runType)76   void OnSingleRunStart(tflite::benchmark::RunType runType) override {
77     if (on_single_run_start_fn_ != nullptr) {
78       on_single_run_start_fn_(user_data_, runType == tflite::benchmark::WARMUP
79                                               ? TfLiteBenchmarkWarmup
80                                               : TfLiteBenchmarkRegular);
81     }
82   }
83 
OnSingleRunEnd()84   void OnSingleRunEnd() override {
85     if (on_single_run_end_fn_ != nullptr) {
86       on_single_run_end_fn_(user_data_);
87     }
88   }
89 
OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults & results)90   void OnBenchmarkEnd(
91       const tflite::benchmark::BenchmarkResults& results) override {
92     if (on_benchmark_end_fn_ != nullptr) {
93       TfLiteBenchmarkResults* wrapper = new TfLiteBenchmarkResults{&results};
94       on_benchmark_end_fn_(user_data_, wrapper);
95       delete wrapper;
96     }
97   }
98 
99   // Keep the user_data pointer provided when setting the callbacks.
100   void* user_data_;
101 
102   // Function pointers set by the TfLiteBenchmarkListenerSetCallbacks call.
103   // Only non-null callbacks will be actually called.
104   void (*on_benchmark_start_fn_)(void* user_data);
105   void (*on_single_run_start_fn_)(void* user_data,
106                                   TfLiteBenchmarkRunType runType);
107   void (*on_single_run_end_fn_)(void* user_data);
108   void (*on_benchmark_end_fn_)(void* user_data,
109                                TfLiteBenchmarkResults* results);
110 };
111 
112 struct TfLiteBenchmarkListener {
113   std::unique_ptr<BenchmarkListenerAdapter> adapter;
114 };
115 
TfLiteBenchmarkListenerCreate()116 TfLiteBenchmarkListener* TfLiteBenchmarkListenerCreate() {
117   std::unique_ptr<BenchmarkListenerAdapter> adapter(
118       new BenchmarkListenerAdapter());
119   return new TfLiteBenchmarkListener{std::move(adapter)};
120 }
121 
TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener * listener)122 void TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener* listener) {
123   delete listener;
124 }
125 
TfLiteBenchmarkListenerSetCallbacks(TfLiteBenchmarkListener * listener,void * user_data,void (* on_benchmark_start_fn)(void * user_data),void (* on_single_run_start_fn)(void * user_data,TfLiteBenchmarkRunType runType),void (* on_single_run_end_fn)(void * user_data),void (* on_benchmark_end_fn)(void * user_data,TfLiteBenchmarkResults * results))126 void TfLiteBenchmarkListenerSetCallbacks(
127     TfLiteBenchmarkListener* listener, void* user_data,
128     void (*on_benchmark_start_fn)(void* user_data),
129     void (*on_single_run_start_fn)(void* user_data,
130                                    TfLiteBenchmarkRunType runType),
131     void (*on_single_run_end_fn)(void* user_data),
132     void (*on_benchmark_end_fn)(void* user_data,
133                                 TfLiteBenchmarkResults* results)) {
134   listener->adapter->user_data_ = user_data;
135   listener->adapter->on_benchmark_start_fn_ = on_benchmark_start_fn;
136   listener->adapter->on_single_run_start_fn_ = on_single_run_start_fn;
137   listener->adapter->on_single_run_end_fn_ = on_single_run_end_fn;
138   listener->adapter->on_benchmark_end_fn_ = on_benchmark_end_fn;
139 }
140 
141 // -----------------------------------------------------------------------------
142 // C APIs corresponding to tflite::benchmark::BenchmarkTfLiteModel type.
143 // -----------------------------------------------------------------------------
144 struct TfLiteBenchmarkTfLiteModel {
145   std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model;
146 };
147 
TfLiteBenchmarkTfLiteModelCreate()148 TfLiteBenchmarkTfLiteModel* TfLiteBenchmarkTfLiteModelCreate() {
149   std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model(
150       new tflite::benchmark::BenchmarkTfLiteModel());
151   return new TfLiteBenchmarkTfLiteModel{std::move(benchmark_model)};
152 }
153 
TfLiteBenchmarkTfLiteModelDelete(TfLiteBenchmarkTfLiteModel * benchmark_model)154 void TfLiteBenchmarkTfLiteModelDelete(
155     TfLiteBenchmarkTfLiteModel* benchmark_model) {
156   delete benchmark_model;
157 }
158 
TfLiteBenchmarkTfLiteModelInit(TfLiteBenchmarkTfLiteModel * benchmark_model)159 TfLiteStatus TfLiteBenchmarkTfLiteModelInit(
160     TfLiteBenchmarkTfLiteModel* benchmark_model) {
161   return benchmark_model->benchmark_model->Init();
162 }
163 
TfLiteBenchmarkTfLiteModelRun(TfLiteBenchmarkTfLiteModel * benchmark_model)164 TfLiteStatus TfLiteBenchmarkTfLiteModelRun(
165     TfLiteBenchmarkTfLiteModel* benchmark_model) {
166   return benchmark_model->benchmark_model->Run();
167 }
168 
TfLiteBenchmarkTfLiteModelRunWithArgs(TfLiteBenchmarkTfLiteModel * benchmark_model,int argc,char ** argv)169 TfLiteStatus TfLiteBenchmarkTfLiteModelRunWithArgs(
170     TfLiteBenchmarkTfLiteModel* benchmark_model, int argc, char** argv) {
171   return benchmark_model->benchmark_model->Run(argc, argv);
172 }
173 
TfLiteBenchmarkTfLiteModelAddListener(TfLiteBenchmarkTfLiteModel * benchmark_model,const TfLiteBenchmarkListener * listener)174 void TfLiteBenchmarkTfLiteModelAddListener(
175     TfLiteBenchmarkTfLiteModel* benchmark_model,
176     const TfLiteBenchmarkListener* listener) {
177   return benchmark_model->benchmark_model->AddListener(listener->adapter.get());
178 }
179 
180 }  // extern "C"
181