• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
17 
18 #include <cstdarg>
19 #include <cstdint>
20 #include <cstdlib>
21 #include <fstream>
22 #include <iostream>
23 #include <memory>
24 #include <random>
25 #include <string>
26 #include <unordered_set>
27 #include <vector>
28 
29 #include "absl/base/attributes.h"
30 #include "absl/strings/numbers.h"
31 #include "ruy/profiler/profiler.h"  // from @ruy
32 #include "tensorflow/lite/c/common.h"
33 #include "tensorflow/lite/kernels/cpu_backend_context.h"
34 #include "tensorflow/lite/kernels/register.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/op_resolver.h"
37 #include "tensorflow/lite/profiling/profile_summary_formatter.h"
38 #include "tensorflow/lite/string_util.h"
39 #include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
40 #include "tensorflow/lite/tools/benchmark/profiling_listener.h"
41 #include "tensorflow/lite/tools/delegates/delegate_provider.h"
42 #include "tensorflow/lite/tools/logging.h"
43 
44 void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
45 
46 // Version with Weak linker attribute doing nothing: if someone links this
47 // library with another definition of this function (presumably to actually
48 // register custom ops), that version will be used instead.
49 void ABSL_ATTRIBUTE_WEAK
RegisterSelectedOps(::tflite::MutableOpResolver * resolver)50 RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {}
51 
52 namespace tflite {
53 namespace benchmark {
54 namespace {
55 
// Default value of --enable_op_profiling. Builds that still define
// TFLITE_PROFILING_ENABLED (the previous way of turning on op profiling) get
// op profiling enabled by default, for backward compatibility.
// The flag is boolean, so the constant is declared as bool rather than int.
#if defined(TFLITE_PROFILING_ENABLED)
constexpr bool kOpProfilingEnabledDefault = true;
#else
constexpr bool kOpProfilingEnabledDefault = false;
#endif
62 
63 // Dumps ruy profiling events if the ruy profiler is enabled.
64 class RuyProfileListener : public BenchmarkListener {
65  public:
66   void OnBenchmarkStart(const BenchmarkParams& params) override;
67 
68   void OnBenchmarkEnd(const BenchmarkResults& results) override;
69 
70  private:
71   std::unique_ptr<ruy::profiler::ScopeProfile> ruy_profile_;
72 };
73 
OnBenchmarkStart(const BenchmarkParams & params)74 void RuyProfileListener::OnBenchmarkStart(const BenchmarkParams& params) {
75   ruy_profile_.reset(new ruy::profiler::ScopeProfile);
76 }
77 
OnBenchmarkEnd(const BenchmarkResults & results)78 void RuyProfileListener::OnBenchmarkEnd(const BenchmarkResults& results) {
79   ruy_profile_ = nullptr;
80 }
81 
Split(const std::string & str,const char delim)82 std::vector<std::string> Split(const std::string& str, const char delim) {
83   std::vector<std::string> results;
84   if (!util::SplitAndParse(str, delim, &results)) {
85     results.clear();
86   }
87   return results;
88 }
89 
GetNumElements(const TfLiteIntArray * dim_array)90 int GetNumElements(const TfLiteIntArray* dim_array) {
91   int num_elements = 1;
92   for (size_t i = 0; i < dim_array->size; i++) {
93     num_elements *= dim_array->data[i];
94   }
95   return num_elements;
96 }
97 
FillRandomString(tflite::DynamicBuffer * buffer,const TfLiteIntArray * dim_array,const std::function<std::string ()> & random_func)98 void FillRandomString(tflite::DynamicBuffer* buffer,
99                       const TfLiteIntArray* dim_array,
100                       const std::function<std::string()>& random_func) {
101   int num_elements = GetNumElements(dim_array);
102   for (int i = 0; i < num_elements; ++i) {
103     auto str = random_func();
104     buffer->AddString(str.data(), str.length());
105   }
106 }
107 
FindLayerInfoIndex(std::vector<BenchmarkTfLiteModel::InputLayerInfo> * info,const std::string & input_name,const string & names_string)108 int FindLayerInfoIndex(std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info,
109                        const std::string& input_name,
110                        const string& names_string) {
111   for (int i = 0; i < info->size(); ++i) {
112     if (info->at(i).name == input_name) {
113       return i;
114     }
115   }
116   TFLITE_LOG(FATAL) << "Cannot find the corresponding input_layer name("
117                     << input_name << ") in --input_layer as " << names_string;
118   return -1;
119 }
120 
PopulateInputValueRanges(const std::string & names_string,const std::string & value_ranges_string,std::vector<BenchmarkTfLiteModel::InputLayerInfo> * info)121 TfLiteStatus PopulateInputValueRanges(
122     const std::string& names_string, const std::string& value_ranges_string,
123     std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
124   std::vector<std::string> value_ranges = Split(value_ranges_string, ':');
125   for (const auto& val : value_ranges) {
126     std::vector<std::string> name_range = Split(val, ',');
127     if (name_range.size() != 3) {
128       TFLITE_LOG(ERROR) << "Wrong input value range item specified: " << val;
129       return kTfLiteError;
130     }
131 
132     // Ensure the specific input layer name exists.
133     int layer_info_idx = FindLayerInfoIndex(info, name_range[0], names_string);
134 
135     // Parse the range value.
136     int low, high;
137     bool has_low = absl::SimpleAtoi(name_range[1], &low);
138     bool has_high = absl::SimpleAtoi(name_range[2], &high);
139     if (!has_low || !has_high || low > high) {
140       TFLITE_LOG(ERROR)
141           << "Wrong low and high value of the input value range specified: "
142           << val;
143       return kTfLiteError;
144     }
145     info->at(layer_info_idx).has_value_range = true;
146     info->at(layer_info_idx).low = low;
147     info->at(layer_info_idx).high = high;
148   }
149   return kTfLiteOk;
150 }
151 
PopulateInputValueFiles(const std::string & names_string,const std::string & value_files_string,std::vector<BenchmarkTfLiteModel::InputLayerInfo> * info)152 TfLiteStatus PopulateInputValueFiles(
153     const std::string& names_string, const std::string& value_files_string,
154     std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
155   std::vector<std::string> value_files = Split(value_files_string, ',');
156   for (const auto& val : value_files) {
157     std::vector<std::string> name_file = Split(val, ':');
158     if (name_file.size() != 2) {
159       TFLITE_LOG(ERROR) << "Wrong input value file item specified: " << val;
160       return kTfLiteError;
161     }
162 
163     // Ensure the specific input layer name exists.
164     int layer_info_idx = FindLayerInfoIndex(info, name_file[0], names_string);
165     if (info->at(layer_info_idx).has_value_range) {
166       TFLITE_LOG(WARN)
167           << "The input_name:" << info->at(layer_info_idx).name
168           << " appears both in input_layer_value_files and "
169              "input_layer_value_range. The input_layer_value_range of the "
170              "input_name will be ignored.";
171     }
172     info->at(layer_info_idx).input_file_path = name_file[1];
173   }
174   return kTfLiteOk;
175 }
176 
PopulateInputLayerInfo(const std::string & names_string,const std::string & shapes_string,const std::string & value_ranges_string,const std::string & value_files_string,std::vector<BenchmarkTfLiteModel::InputLayerInfo> * info)177 TfLiteStatus PopulateInputLayerInfo(
178     const std::string& names_string, const std::string& shapes_string,
179     const std::string& value_ranges_string,
180     const std::string& value_files_string,
181     std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
182   info->clear();
183   std::vector<std::string> names = Split(names_string, ',');
184   std::vector<std::string> shapes = Split(shapes_string, ':');
185 
186   if (names.size() != shapes.size()) {
187     TFLITE_LOG(ERROR) << "The number of items in"
188                       << " --input_layer_shape (" << shapes_string << ", with "
189                       << shapes.size() << " items)"
190                       << " must match the number of items in"
191                       << " --input_layer (" << names_string << ", with "
192                       << names.size() << " items)."
193                       << " For example --input_layer=input1,input2"
194                       << " --input_layer_shape=1,224,224,4:1,20";
195     return kTfLiteError;
196   }
197 
198   for (int i = 0; i < names.size(); ++i) {
199     info->push_back(BenchmarkTfLiteModel::InputLayerInfo());
200     BenchmarkTfLiteModel::InputLayerInfo& input = info->back();
201 
202     input.name = names[i];
203 
204     TFLITE_TOOLS_CHECK(util::SplitAndParse(shapes[i], ',', &input.shape))
205         << "Incorrect size string specified: " << shapes[i];
206     for (int dim : input.shape) {
207       if (dim == -1) {
208         TFLITE_LOG(ERROR)
209             << "Any unknown sizes in the shapes (-1's) must be replaced"
210             << " with the size you want to benchmark with.";
211         return kTfLiteError;
212       }
213     }
214   }
215 
216   // Populate input value range if it's specified.
217   TF_LITE_ENSURE_STATUS(
218       PopulateInputValueRanges(names_string, value_ranges_string, info));
219 
220   // Populate input value files if it's specified.
221   TF_LITE_ENSURE_STATUS(
222       PopulateInputValueFiles(names_string, value_files_string, info));
223 
224   return kTfLiteOk;
225 }
226 
227 std::shared_ptr<profiling::ProfileSummaryFormatter>
CreateProfileSummaryFormatter(bool format_as_csv)228 CreateProfileSummaryFormatter(bool format_as_csv) {
229   return format_as_csv
230              ? std::make_shared<profiling::ProfileSummaryCSVFormatter>()
231              : std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
232 }
233 
234 }  // namespace
235 
DefaultParams()236 BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
237   BenchmarkParams default_params = BenchmarkModel::DefaultParams();
238   default_params.AddParam("graph", BenchmarkParam::Create<std::string>(""));
239   default_params.AddParam("input_layer",
240                           BenchmarkParam::Create<std::string>(""));
241   default_params.AddParam("input_layer_shape",
242                           BenchmarkParam::Create<std::string>(""));
243   default_params.AddParam("input_layer_value_range",
244                           BenchmarkParam::Create<std::string>(""));
245   default_params.AddParam("input_layer_value_files",
246                           BenchmarkParam::Create<std::string>(""));
247   default_params.AddParam("allow_fp16", BenchmarkParam::Create<bool>(false));
248   default_params.AddParam("require_full_delegation",
249                           BenchmarkParam::Create<bool>(false));
250   default_params.AddParam(
251       "enable_op_profiling",
252       BenchmarkParam::Create<bool>(kOpProfilingEnabledDefault));
253   default_params.AddParam("max_profiling_buffer_entries",
254                           BenchmarkParam::Create<int32_t>(1024));
255   default_params.AddParam("profiling_output_csv_file",
256                           BenchmarkParam::Create<std::string>(""));
257 
258   for (const auto& delegate_provider :
259        tools::GetRegisteredDelegateProviders()) {
260     default_params.Merge(delegate_provider->DefaultParams());
261   }
262 
263   return default_params;
264 }
265 
BenchmarkTfLiteModel(BenchmarkParams params)266 BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
267     : BenchmarkModel(std::move(params)),
268       random_engine_(std::random_device()()) {
269   AddListener(&log_output_);
270 }
271 
CleanUp()272 void BenchmarkTfLiteModel::CleanUp() {
273   // Free up any pre-allocated tensor data during PrepareInputData.
274   inputs_data_.clear();
275 }
276 
~BenchmarkTfLiteModel()277 BenchmarkTfLiteModel::~BenchmarkTfLiteModel() {
278   CleanUp();
279 
280   // Destory the owned interpreter earlier than other objects (specially
281   // 'owned_delegates_').
282   interpreter_.reset();
283 }
284 
GetFlags()285 std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
286   std::vector<Flag> flags = BenchmarkModel::GetFlags();
287   std::vector<Flag> specific_flags = {
288       CreateFlag<std::string>("graph", &params_, "graph file name"),
289       CreateFlag<std::string>("input_layer", &params_, "input layer names"),
290       CreateFlag<std::string>("input_layer_shape", &params_,
291                               "input layer shape"),
292       CreateFlag<std::string>(
293           "input_layer_value_range", &params_,
294           "A map-like string representing value range for *integer* input "
295           "layers. Each item is separated by ':', and the item value consists "
296           "of input layer name and integer-only range values (both low and "
297           "high are inclusive) separated by ',', e.g. input1,1,2:input2,0,254"),
298       CreateFlag<std::string>(
299           "input_layer_value_files", &params_,
300           "A map-like string representing value file. Each item is separated "
301           "by ',', and the item value consists "
302           "of input layer name and value file path separated by ':', e.g. "
303           "input1:file_path1,input2:file_path2. If the input_name appears both "
304           "in input_layer_value_range and input_layer_value_files, "
305           "input_layer_value_range of the input_name will be ignored. The file "
306           "format is binary and it should be array format or null separated "
307           "strings format."),
308       CreateFlag<bool>("allow_fp16", &params_, "allow fp16"),
309       CreateFlag<bool>("require_full_delegation", &params_,
310                        "require delegate to run the entire graph"),
311       CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"),
312       CreateFlag<int32_t>("max_profiling_buffer_entries", &params_,
313                           "max profiling buffer entries"),
314       CreateFlag<std::string>(
315           "profiling_output_csv_file", &params_,
316           "File path to export profile data as CSV, if not set "
317           "prints to stdout.")};
318 
319   flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
320 
321   for (const auto& delegate_provider :
322        tools::GetRegisteredDelegateProviders()) {
323     auto delegate_flags = delegate_provider->CreateFlags(&params_);
324     flags.insert(flags.end(), delegate_flags.begin(), delegate_flags.end());
325   }
326 
327   return flags;
328 }
329 
LogParams()330 void BenchmarkTfLiteModel::LogParams() {
331   BenchmarkModel::LogParams();
332   const bool verbose = params_.Get<bool>("verbose");
333   // Always log the value of --graph.
334   LOG_BENCHMARK_PARAM(std::string, "graph", "Graph", /*verbose*/ true);
335   LOG_BENCHMARK_PARAM(std::string, "input_layer", "Input layers", verbose);
336   LOG_BENCHMARK_PARAM(std::string, "input_layer_shape", "Input shapes",
337                       verbose);
338   LOG_BENCHMARK_PARAM(std::string, "input_layer_value_range",
339                       "Input value ranges", verbose);
340   LOG_BENCHMARK_PARAM(std::string, "input_layer_value_files",
341                       "Input value files", verbose);
342 
343   LOG_BENCHMARK_PARAM(bool, "allow_fp16", "Allow fp16", verbose);
344   LOG_BENCHMARK_PARAM(bool, "require_full_delegation",
345                       "Require full delegation", verbose);
346   LOG_BENCHMARK_PARAM(bool, "enable_op_profiling", "Enable op profiling",
347                       verbose);
348   LOG_BENCHMARK_PARAM(int32_t, "max_profiling_buffer_entries",
349                       "Max profiling buffer entries", verbose);
350   LOG_BENCHMARK_PARAM(std::string, "profiling_output_csv_file",
351                       "CSV File to export profiling data to", verbose);
352 
353   for (const auto& delegate_provider :
354        tools::GetRegisteredDelegateProviders()) {
355     delegate_provider->LogParams(params_, verbose);
356   }
357 }
358 
ValidateParams()359 TfLiteStatus BenchmarkTfLiteModel::ValidateParams() {
360   if (params_.Get<std::string>("graph").empty()) {
361     TFLITE_LOG(ERROR)
362         << "Please specify the name of your TF Lite input file with --graph";
363     return kTfLiteError;
364   }
365 
366   return PopulateInputLayerInfo(
367       params_.Get<std::string>("input_layer"),
368       params_.Get<std::string>("input_layer_shape"),
369       params_.Get<std::string>("input_layer_value_range"),
370       params_.Get<std::string>("input_layer_value_files"), &inputs_);
371 }
372 
ComputeInputBytes()373 uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
374   TFLITE_TOOLS_CHECK(interpreter_);
375   uint64_t total_input_bytes = 0;
376   for (int input : interpreter_->inputs()) {
377     auto* t = interpreter_->tensor(input);
378     total_input_bytes += t->bytes;
379   }
380   return total_input_bytes;
381 }
382 
MayGetModelFileSize()383 int64_t BenchmarkTfLiteModel::MayGetModelFileSize() {
384   std::ifstream in_file(params_.Get<std::string>("graph"),
385                         std::ios::binary | std::ios::ate);
386   return in_file.tellg();
387 }
388 
LoadInputTensorData(const TfLiteTensor & t,const std::string & input_file_path)389 BenchmarkTfLiteModel::InputTensorData BenchmarkTfLiteModel::LoadInputTensorData(
390     const TfLiteTensor& t, const std::string& input_file_path) {
391   std::ifstream value_file(input_file_path, std::ios::binary);
392   if (!value_file.good()) {
393     TFLITE_LOG(FATAL) << "Failed to read the input_layer_value_file:"
394                       << input_file_path;
395   }
396   InputTensorData t_data;
397   if (t.type == kTfLiteString) {
398     t_data.data = VoidUniquePtr(
399         static_cast<void*>(new tflite::DynamicBuffer()),
400         [](void* ptr) { delete static_cast<DynamicBuffer*>(ptr); });
401     std::string line;
402     size_t num_line = 0;
403     // Read the line with the delimiter '\0'.
404     while (std::getline(value_file, line, '\0')) {
405       num_line++;
406       static_cast<DynamicBuffer*>(t_data.data.get())
407           ->AddString(line.data(), line.length());
408     }
409     int num_elements = GetNumElements(t.dims);
410     if (num_line != num_elements) {
411       TFLITE_LOG(FATAL) << "The number of string in the input_layer_value_file("
412                         << input_file_path << ") is " << num_line
413                         << ". It should be " << num_elements << ".";
414     }
415   } else {
416     value_file.seekg(0, std::ios_base::end);
417     if (value_file.tellg() != t.bytes) {
418       TFLITE_LOG(FATAL) << "The size of " << input_file_path << " is "
419                         << value_file.tellg() << " bytes. It should be "
420                         << t.bytes << " bytes.";
421     }
422     t_data.bytes = t.bytes;
423     t_data.data =
424         VoidUniquePtr(static_cast<void*>(new char[t.bytes]),
425                       [](void* ptr) { delete[] static_cast<char*>(ptr); });
426     value_file.clear();
427     value_file.seekg(0, std::ios_base::beg);
428     value_file.read(static_cast<char*>(t_data.data.get()), t.bytes);
429   }
430   return t_data;
431 }
432 
433 BenchmarkTfLiteModel::InputTensorData
CreateRandomTensorData(const TfLiteTensor & t,const InputLayerInfo * layer_info)434 BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t,
435                                              const InputLayerInfo* layer_info) {
436   bool has_value_range = false;
437   int low_range = 0;
438   int high_range = 0;
439   if (layer_info) {
440     has_value_range = layer_info->has_value_range;
441     low_range = layer_info->low;
442     high_range = layer_info->high;
443   }
444   int num_elements = GetNumElements(t.dims);
445   switch (t.type) {
446     case kTfLiteFloat32: {
447       return CreateInputTensorData<float>(
448           num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
449     }
450     case kTfLiteFloat16: {
451       // TODO(b/138843274): Remove this preprocessor guard when bug is fixed.
452 #if TFLITE_ENABLE_FP16_CPU_BENCHMARKS
453 #if __GNUC__ && \
454     (__clang__ || __ARM_FP16_FORMAT_IEEE || __ARM_FP16_FORMAT_ALTERNATIVE)
455       // __fp16 is available on Clang or when __ARM_FP16_FORMAT_* is defined.
456       return CreateInputTensorData<__fp16>(
457           num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
458 #else
459       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
460                         << " of type FLOAT16 on this platform.";
461 #endif
462 #else
463       // You need to build with -DTFLITE_ENABLE_FP16_CPU_BENCHMARKS=1 using a
464       // compiler that supports __fp16 type. Note: when using Clang and *not*
465       // linking with compiler-rt, a definition of __gnu_h2f_ieee and
466       // __gnu_f2h_ieee must be supplied.
467       TFLITE_LOG(FATAL) << "Populating the tensor " << t.name
468                         << " of type FLOAT16 is disabled.";
469 #endif  // TFLITE_ENABLE_FP16_CPU_BENCHMARKS
470       break;
471     }
472     case kTfLiteFloat64: {
473       return CreateInputTensorData<double>(
474           num_elements, std::uniform_real_distribution<double>(-0.5, 0.5));
475     }
476     case kTfLiteInt64: {
477       int low = has_value_range ? low_range : 0;
478       int high = has_value_range ? high_range : 99;
479       return CreateInputTensorData<int64_t>(
480           num_elements, std::uniform_int_distribution<int64_t>(low, high));
481     }
482     case kTfLiteInt32: {
483       int low = has_value_range ? low_range : 0;
484       int high = has_value_range ? high_range : 99;
485       return CreateInputTensorData<int32_t>(
486           num_elements, std::uniform_int_distribution<int32_t>(low, high));
487     }
488     case kTfLiteUInt32: {
489       int low = has_value_range ? low_range : 0;
490       int high = has_value_range ? high_range : 99;
491       return CreateInputTensorData<uint32_t>(
492           num_elements, std::uniform_int_distribution<uint32_t>(low, high));
493     }
494     case kTfLiteInt16: {
495       int low = has_value_range ? low_range : 0;
496       int high = has_value_range ? high_range : 99;
497       return CreateInputTensorData<int16_t>(
498           num_elements, std::uniform_int_distribution<int16_t>(low, high));
499     }
500     case kTfLiteUInt8: {
501       int low = has_value_range ? low_range : 0;
502       int high = has_value_range ? high_range : 254;
503       // std::uniform_int_distribution is specified not to support char types.
504       return CreateInputTensorData<uint8_t>(
505           num_elements, std::uniform_int_distribution<uint32_t>(low, high));
506     }
507     case kTfLiteInt8: {
508       int low = has_value_range ? low_range : -127;
509       int high = has_value_range ? high_range : 127;
510       // std::uniform_int_distribution is specified not to support char types.
511       return CreateInputTensorData<int8_t>(
512           num_elements, std::uniform_int_distribution<int32_t>(low, high));
513     }
514     case kTfLiteString: {
515       // TODO(haoliang): No need to cache string tensors right now.
516       break;
517     }
518     case kTfLiteBool: {
519       // According to std::uniform_int_distribution specification, non-int type
520       // is not supported.
521       return CreateInputTensorData<bool>(
522           num_elements, std::uniform_int_distribution<uint32_t>(0, 1));
523     }
524     default: {
525       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t.name
526                         << " of type " << t.type;
527     }
528   }
529   return InputTensorData();
530 }
531 
PrepareInputData()532 TfLiteStatus BenchmarkTfLiteModel::PrepareInputData() {
533   CleanUp();
534 
535   // Note the corresponding relation between 'interpreter_inputs' and 'inputs_'
536   // (i.e. the specified input layer info) has been checked in
537   // BenchmarkTfLiteModel::Init() before calling this function. So, we simply
538   // use the corresponding input layer info to initializethe input data value
539   // properly.
540   auto interpreter_inputs = interpreter_->inputs();
541   for (int i = 0; i < interpreter_inputs.size(); ++i) {
542     int tensor_index = interpreter_inputs[i];
543     const TfLiteTensor& t = *(interpreter_->tensor(tensor_index));
544     const InputLayerInfo* input_layer_info = nullptr;
545     // Note that when input layer parameters (i.e. --input_layer,
546     // --input_layer_shape) are not specified, inputs_ is empty.
547     if (!inputs_.empty()) input_layer_info = &inputs_[i];
548 
549     InputTensorData t_data;
550     if (input_layer_info && !input_layer_info->input_file_path.empty()) {
551       t_data = LoadInputTensorData(t, input_layer_info->input_file_path);
552     } else {
553       t_data = CreateRandomTensorData(t, input_layer_info);
554     }
555     inputs_data_.push_back(std::move(t_data));
556   }
557   return kTfLiteOk;
558 }
559 
ResetInputsAndOutputs()560 TfLiteStatus BenchmarkTfLiteModel::ResetInputsAndOutputs() {
561   auto interpreter_inputs = interpreter_->inputs();
562   // Set the values of the input tensors from inputs_data_.
563   for (int j = 0; j < interpreter_inputs.size(); ++j) {
564     int i = interpreter_inputs[j];
565     TfLiteTensor* t = interpreter_->tensor(i);
566     if (t->type == kTfLiteString) {
567       if (inputs_data_[j].data) {
568         static_cast<DynamicBuffer*>(inputs_data_[j].data.get())
569             ->WriteToTensor(t, /*new_shape=*/nullptr);
570       } else {
571         tflite::DynamicBuffer buffer;
572         FillRandomString(&buffer, t->dims, []() {
573           return "we're have some friends over saturday to hang out in the "
574                  "yard";
575         });
576         buffer.WriteToTensor(t, /*new_shape=*/nullptr);
577       }
578     } else {
579       std::memcpy(t->data.raw, inputs_data_[j].data.get(),
580                   inputs_data_[j].bytes);
581     }
582   }
583 
584   return kTfLiteOk;
585 }
586 
InitInterpreter()587 TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
588   auto resolver = GetOpResolver();
589   const int32_t num_threads = params_.Get<int32_t>("num_threads");
590   const bool use_caching = params_.Get<bool>("use_caching");
591   tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
592   if (!interpreter_) {
593     TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
594     return kTfLiteError;
595   }
596   // Manually enable caching behavior in TF Lite interpreter.
597   if (use_caching) {
598     external_context_.reset(new tflite::ExternalCpuBackendContext());
599     std::unique_ptr<tflite::CpuBackendContext> cpu_backend_context(
600         new tflite::CpuBackendContext());
601     cpu_backend_context->SetUseCaching(true);
602     cpu_backend_context->SetMaxNumThreads(num_threads);
603     external_context_->set_internal_backend_context(
604         std::move(cpu_backend_context));
605     interpreter_->SetExternalContext(kTfLiteCpuBackendContext,
606                                      external_context_.get());
607   }
608 
609   return kTfLiteOk;
610 }
611 
Init()612 TfLiteStatus BenchmarkTfLiteModel::Init() {
613   TF_LITE_ENSURE_STATUS(LoadModel());
614   TF_LITE_ENSURE_STATUS(InitInterpreter());
615 
616   // Install profilers if necessary right after interpreter is created so that
617   // any memory allocations inside the TFLite runtime could be recorded if the
618   // installed profiler profile memory usage information.
619   profiling_listener_ = MayCreateProfilingListener();
620   if (profiling_listener_) AddListener(profiling_listener_.get());
621 
622   interpreter_->SetAllowFp16PrecisionForFp32(params_.Get<bool>("allow_fp16"));
623 
624   owned_delegates_.clear();
625 
626   // Contains all ids of TfLiteNodes that have been checked to see whether it's
627   // delegated or not.
628   std::unordered_set<int> checked_node_ids;
629   for (const auto& delegate_provider :
630        tools::GetRegisteredDelegateProviders()) {
631     auto delegate = delegate_provider->CreateTfLiteDelegate(params_);
632     // It's possible that a delegate of certain type won't be created as
633     // user-specified benchmark params tells not to.
634     if (delegate == nullptr) continue;
635     if (interpreter_->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
636       TFLITE_LOG(ERROR) << "Failed to apply " << delegate_provider->GetName()
637                         << " delegate.";
638       return kTfLiteError;
639     } else {
640       // Ideally, such delegate info should already be computed when the
641       // delegate is being applied to the model graph.
642       int num_delegated_kernels = 0;
643       for (int i = 0; i < interpreter_->execution_plan().size(); ++i) {
644         int node_id = interpreter_->execution_plan()[i];
645         if (checked_node_ids.find(node_id) != checked_node_ids.end()) {
646           continue;
647         }
648         const TfLiteNode& node =
649             interpreter_->node_and_registration(node_id)->first;
650 
651         // Note that the 'delegate' here could be an ExternalDelegateWrapper
652         // object that wraps an actual external delegate, in which case,
653         // 'node.delegate' will be different from 'delegate' because
654         // 'node.delegate' refers to the actual external delegate.
655         if (node.delegate != nullptr) {
656           num_delegated_kernels++;
657           checked_node_ids.insert(node_id);
658         }
659       }
660       bool fully_delegated = (num_delegated_kernels == 1 &&
661                               interpreter_->execution_plan().size() == 1);
662 
663       if (params_.Get<bool>("require_full_delegation") && !fully_delegated) {
664         TFLITE_LOG(ERROR) << "Disallowed CPU fallback detected.";
665         return kTfLiteError;
666       }
667       if (fully_delegated) {
668         TFLITE_LOG(INFO) << "Explicitly applied "
669                          << delegate_provider->GetName()
670                          << " delegate, and the model graph will be completely"
671                          << " executed by the delegate.";
672       } else if (num_delegated_kernels > 0) {
673         TFLITE_LOG(INFO) << "Explicitly applied "
674                          << delegate_provider->GetName()
675                          << " delegate, and the model graph will be partially"
676                          << " executed by the delegate w/ "
677                          << num_delegated_kernels << " delegate kernels.";
678       } else {
679         TFLITE_LOG(INFO)
680             << "Though " << delegate_provider->GetName()
681             << " delegate is explicitly applied, the model graph will not be"
682             << " executed by the delegate.";
683       }
684     }
685     owned_delegates_.emplace_back(std::move(delegate));
686   }
687 
688   auto interpreter_inputs = interpreter_->inputs();
689 
690   if (!inputs_.empty()) {
691     TFLITE_TOOLS_CHECK_EQ(inputs_.size(), interpreter_inputs.size())
692         << "Inputs mismatch: Model inputs #:" << inputs_.size()
693         << " expected: " << interpreter_inputs.size();
694   }
695 
696   // Check if the tensor names match, and log a warning if it doesn't.
697   // TODO(ycling): Consider to make this an error again when the new converter
698   // create tensors with consistent naming.
699   for (int j = 0; j < inputs_.size(); ++j) {
700     const InputLayerInfo& input = inputs_[j];
701     int i = interpreter_inputs[j];
702     TfLiteTensor* t = interpreter_->tensor(i);
703     if (input.name != t->name) {
704       TFLITE_LOG(WARN) << "Tensor # " << i << " is named " << t->name
705                        << " but flags call it " << input.name;
706     }
707 
708     if (input.shape.size() != t->dims->size) {
709       TFLITE_LOG(ERROR) << "Input tensor #" << i << " should have "
710                         << t->dims->size << " dimensions!";
711       return kTfLiteError;
712     }
713   }
714 
715   // Resize all non-string tensors.
716   for (int j = 0; j < inputs_.size(); ++j) {
717     const InputLayerInfo& input = inputs_[j];
718     int i = interpreter_inputs[j];
719     TfLiteTensor* t = interpreter_->tensor(i);
720     if (t->type != kTfLiteString) {
721       interpreter_->ResizeInputTensor(i, input.shape);
722     }
723   }
724 
725   if (interpreter_->AllocateTensors() != kTfLiteOk) {
726     TFLITE_LOG(ERROR) << "Failed to allocate tensors!";
727     return kTfLiteError;
728   }
729 
730   ruy_profiling_listener_.reset(new RuyProfileListener());
731   AddListener(ruy_profiling_listener_.get());
732 
733   return kTfLiteOk;
734 }
735 
LoadModel()736 TfLiteStatus BenchmarkTfLiteModel::LoadModel() {
737   std::string graph = params_.Get<std::string>("graph");
738   model_ = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
739   if (!model_) {
740     TFLITE_LOG(ERROR) << "Failed to mmap model " << graph;
741     return kTfLiteError;
742   }
743   TFLITE_LOG(INFO) << "Loaded model " << graph;
744   return kTfLiteOk;
745 }
746 
GetOpResolver() const747 std::unique_ptr<tflite::OpResolver> BenchmarkTfLiteModel::GetOpResolver()
748     const {
749   auto resolver = new tflite::ops::builtin::BuiltinOpResolver();
750   RegisterSelectedOps(resolver);
751   return std::unique_ptr<tflite::OpResolver>(resolver);
752 }
753 
754 std::unique_ptr<BenchmarkListener>
MayCreateProfilingListener() const755 BenchmarkTfLiteModel::MayCreateProfilingListener() const {
756   if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
757 
758   return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
759       interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
760       params_.Get<std::string>("profiling_output_csv_file"),
761       CreateProfileSummaryFormatter(
762           !params_.Get<std::string>("profiling_output_csv_file").empty())));
763 }
764 
RunImpl()765 TfLiteStatus BenchmarkTfLiteModel::RunImpl() { return interpreter_->Invoke(); }
766 
767 }  // namespace benchmark
768 }  // namespace tflite
769