1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
17
18 #include <string>
19
20 #include "tensorflow/lite/tools/logging.h"
21
22 namespace tflite {
23 namespace evaluation {
namespace {

// Textual delegate names that ParseStringToDelegateType() compares against.
// Kept as char arrays so they can be compared directly with std::string.
constexpr char kNnapiDelegate[] = "nnapi";      // -> TfliteInferenceParams::NNAPI
constexpr char kGpuDelegate[] = "gpu";          // -> TfliteInferenceParams::GPU
constexpr char kHexagonDelegate[] = "hexagon";  // -> TfliteInferenceParams::HEXAGON
constexpr char kXnnpackDelegate[] = "xnnpack";  // -> TfliteInferenceParams::XNNPACK
constexpr char kCoremlDelegate[] = "coreml";    // -> TfliteInferenceParams::COREML

}  // namespace
31
ParseStringToDelegateType(const std::string & val)32 TfliteInferenceParams::Delegate ParseStringToDelegateType(
33 const std::string& val) {
34 if (val == kNnapiDelegate) return TfliteInferenceParams::NNAPI;
35 if (val == kGpuDelegate) return TfliteInferenceParams::GPU;
36 if (val == kHexagonDelegate) return TfliteInferenceParams::HEXAGON;
37 if (val == kXnnpackDelegate) return TfliteInferenceParams::XNNPACK;
38 if (val == kCoremlDelegate) return TfliteInferenceParams::COREML;
39 return TfliteInferenceParams::NONE;
40 }
41
CreateTfLiteDelegate(const TfliteInferenceParams & params,std::string * error_msg)42 TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
43 std::string* error_msg) {
44 const auto type = params.delegate();
45 switch (type) {
46 case TfliteInferenceParams::NNAPI: {
47 auto p = CreateNNAPIDelegate();
48 if (!p && error_msg) *error_msg = "NNAPI not supported";
49 return p;
50 }
51 case TfliteInferenceParams::GPU: {
52 auto p = CreateGPUDelegate();
53 if (!p && error_msg) *error_msg = "GPU delegate not supported.";
54 return p;
55 }
56 case TfliteInferenceParams::HEXAGON: {
57 auto p = CreateHexagonDelegate(/*library_directory_path=*/"",
58 /*profiling=*/false);
59 if (!p && error_msg) {
60 *error_msg =
61 "Hexagon delegate is not supported on the platform or required "
62 "libraries are missing.";
63 }
64 return p;
65 }
66 case TfliteInferenceParams::XNNPACK: {
67 auto p = CreateXNNPACKDelegate(params.num_threads());
68 if (!p && error_msg) *error_msg = "XNNPACK delegate not supported.";
69 return p;
70 }
71 case TfliteInferenceParams::COREML: {
72 auto p = CreateCoreMlDelegate();
73 if (!p && error_msg) *error_msg = "CoreML delegate not supported.";
74 return p;
75 }
76 case TfliteInferenceParams::NONE:
77 return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
78 default:
79 if (error_msg) {
80 *error_msg = "Creation of delegate type: " +
81 TfliteInferenceParams::Delegate_Name(type) +
82 " not supported yet.";
83 }
84 return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
85 }
86 }
87
DelegateProviders()88 DelegateProviders::DelegateProviders()
89 : delegate_list_util_(¶ms_),
90 delegates_map_([=]() -> std::unordered_map<std::string, int> {
91 std::unordered_map<std::string, int> delegates_map;
92 const auto& providers = delegate_list_util_.providers();
93 for (int i = 0; i < providers.size(); ++i) {
94 delegates_map[providers[i]->GetName()] = i;
95 }
96 return delegates_map;
97 }()) {
98 delegate_list_util_.AddAllDelegateParams();
99 }
100
GetFlags()101 std::vector<Flag> DelegateProviders::GetFlags() {
102 std::vector<Flag> flags;
103 delegate_list_util_.AppendCmdlineFlags(flags);
104 return flags;
105 }
106
InitFromCmdlineArgs(int * argc,const char ** argv)107 bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
108 std::vector<Flag> flags = GetFlags();
109 bool parse_result = Flags::Parse(argc, argv, flags);
110 if (!parse_result || params_.Get<bool>("help")) {
111 std::string usage = Flags::Usage(argv[0], flags);
112 TFLITE_LOG(ERROR) << usage;
113 // Returning false intentionally when "--help=true" is specified so that
114 // the caller could check the return value to decide stopping the execution.
115 parse_result = false;
116 }
117 return parse_result;
118 }
119
CreateDelegate(const std::string & name) const120 TfLiteDelegatePtr DelegateProviders::CreateDelegate(
121 const std::string& name) const {
122 const auto it = delegates_map_.find(name);
123 if (it == delegates_map_.end()) {
124 return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
125 }
126 const auto& providers = delegate_list_util_.providers();
127 return providers[it->second]->CreateTfLiteDelegate(params_);
128 }
129
GetAllParams(const TfliteInferenceParams & params) const130 tools::ToolParams DelegateProviders::GetAllParams(
131 const TfliteInferenceParams& params) const {
132 tools::ToolParams tool_params;
133 tool_params.Merge(params_, /*overwrite*/ false);
134
135 if (params.has_num_threads()) {
136 tool_params.Set<int32_t>("num_threads", params.num_threads());
137 }
138
139 const auto type = params.delegate();
140 switch (type) {
141 case TfliteInferenceParams::NNAPI:
142 if (tool_params.HasParam("use_nnapi")) {
143 tool_params.Set<bool>("use_nnapi", true);
144 }
145 break;
146 case TfliteInferenceParams::GPU:
147 if (tool_params.HasParam("use_gpu")) {
148 tool_params.Set<bool>("use_gpu", true);
149 }
150 break;
151 case TfliteInferenceParams::HEXAGON:
152 if (tool_params.HasParam("use_hexagon")) {
153 tool_params.Set<bool>("use_hexagon", true);
154 }
155 break;
156 case TfliteInferenceParams::XNNPACK:
157 if (tool_params.HasParam("use_xnnpack")) {
158 tool_params.Set<bool>("use_xnnpack", true);
159 }
160 break;
161 case TfliteInferenceParams::COREML:
162 if (tool_params.HasParam("use_coreml")) {
163 tool_params.Set<bool>("use_coreml", true);
164 }
165 break;
166 default:
167 break;
168 }
169 return tool_params;
170 }
171
172 } // namespace evaluation
173 } // namespace tflite
174