• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h"
17 
18 #include <string>
19 
20 #include "tensorflow/lite/tools/logging.h"
21 
22 namespace tflite {
23 namespace evaluation {
namespace {
// String values accepted for the delegate selection flag; matched against
// user input in ParseStringToDelegateType().
constexpr char kCoremlDelegate[] = "coreml";
constexpr char kGpuDelegate[] = "gpu";
constexpr char kHexagonDelegate[] = "hexagon";
constexpr char kNnapiDelegate[] = "nnapi";
constexpr char kXnnpackDelegate[] = "xnnpack";
}  // namespace
31 
ParseStringToDelegateType(const std::string & val)32 TfliteInferenceParams::Delegate ParseStringToDelegateType(
33     const std::string& val) {
34   if (val == kNnapiDelegate) return TfliteInferenceParams::NNAPI;
35   if (val == kGpuDelegate) return TfliteInferenceParams::GPU;
36   if (val == kHexagonDelegate) return TfliteInferenceParams::HEXAGON;
37   if (val == kXnnpackDelegate) return TfliteInferenceParams::XNNPACK;
38   if (val == kCoremlDelegate) return TfliteInferenceParams::COREML;
39   return TfliteInferenceParams::NONE;
40 }
41 
CreateTfLiteDelegate(const TfliteInferenceParams & params,std::string * error_msg)42 TfLiteDelegatePtr CreateTfLiteDelegate(const TfliteInferenceParams& params,
43                                        std::string* error_msg) {
44   const auto type = params.delegate();
45   switch (type) {
46     case TfliteInferenceParams::NNAPI: {
47       auto p = CreateNNAPIDelegate();
48       if (!p && error_msg) *error_msg = "NNAPI not supported";
49       return p;
50     }
51     case TfliteInferenceParams::GPU: {
52       auto p = CreateGPUDelegate();
53       if (!p && error_msg) *error_msg = "GPU delegate not supported.";
54       return p;
55     }
56     case TfliteInferenceParams::HEXAGON: {
57       auto p = CreateHexagonDelegate(/*library_directory_path=*/"",
58                                      /*profiling=*/false);
59       if (!p && error_msg) {
60         *error_msg =
61             "Hexagon delegate is not supported on the platform or required "
62             "libraries are missing.";
63       }
64       return p;
65     }
66     case TfliteInferenceParams::XNNPACK: {
67       auto p = CreateXNNPACKDelegate(params.num_threads());
68       if (!p && error_msg) *error_msg = "XNNPACK delegate not supported.";
69       return p;
70     }
71     case TfliteInferenceParams::COREML: {
72       auto p = CreateCoreMlDelegate();
73       if (!p && error_msg) *error_msg = "CoreML delegate not supported.";
74       return p;
75     }
76     case TfliteInferenceParams::NONE:
77       return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
78     default:
79       if (error_msg) {
80         *error_msg = "Creation of delegate type: " +
81                      TfliteInferenceParams::Delegate_Name(type) +
82                      " not supported yet.";
83       }
84       return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
85   }
86 }
87 
DelegateProviders()88 DelegateProviders::DelegateProviders()
89     : delegate_list_util_(&params_),
90       delegates_map_([=]() -> std::unordered_map<std::string, int> {
91         std::unordered_map<std::string, int> delegates_map;
92         const auto& providers = delegate_list_util_.providers();
93         for (int i = 0; i < providers.size(); ++i) {
94           delegates_map[providers[i]->GetName()] = i;
95         }
96         return delegates_map;
97       }()) {
98   delegate_list_util_.AddAllDelegateParams();
99 }
100 
GetFlags()101 std::vector<Flag> DelegateProviders::GetFlags() {
102   std::vector<Flag> flags;
103   delegate_list_util_.AppendCmdlineFlags(flags);
104   return flags;
105 }
106 
InitFromCmdlineArgs(int * argc,const char ** argv)107 bool DelegateProviders::InitFromCmdlineArgs(int* argc, const char** argv) {
108   std::vector<Flag> flags = GetFlags();
109   bool parse_result = Flags::Parse(argc, argv, flags);
110   if (!parse_result || params_.Get<bool>("help")) {
111     std::string usage = Flags::Usage(argv[0], flags);
112     TFLITE_LOG(ERROR) << usage;
113     // Returning false intentionally when "--help=true" is specified so that
114     // the caller could check the return value to decide stopping the execution.
115     parse_result = false;
116   }
117   return parse_result;
118 }
119 
CreateDelegate(const std::string & name) const120 TfLiteDelegatePtr DelegateProviders::CreateDelegate(
121     const std::string& name) const {
122   const auto it = delegates_map_.find(name);
123   if (it == delegates_map_.end()) {
124     return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
125   }
126   const auto& providers = delegate_list_util_.providers();
127   return providers[it->second]->CreateTfLiteDelegate(params_);
128 }
129 
GetAllParams(const TfliteInferenceParams & params) const130 tools::ToolParams DelegateProviders::GetAllParams(
131     const TfliteInferenceParams& params) const {
132   tools::ToolParams tool_params;
133   tool_params.Merge(params_, /*overwrite*/ false);
134 
135   if (params.has_num_threads()) {
136     tool_params.Set<int32_t>("num_threads", params.num_threads());
137   }
138 
139   const auto type = params.delegate();
140   switch (type) {
141     case TfliteInferenceParams::NNAPI:
142       if (tool_params.HasParam("use_nnapi")) {
143         tool_params.Set<bool>("use_nnapi", true);
144       }
145       break;
146     case TfliteInferenceParams::GPU:
147       if (tool_params.HasParam("use_gpu")) {
148         tool_params.Set<bool>("use_gpu", true);
149       }
150       break;
151     case TfliteInferenceParams::HEXAGON:
152       if (tool_params.HasParam("use_hexagon")) {
153         tool_params.Set<bool>("use_hexagon", true);
154       }
155       break;
156     case TfliteInferenceParams::XNNPACK:
157       if (tool_params.HasParam("use_xnnpack")) {
158         tool_params.Set<bool>("use_xnnpack", true);
159       }
160       break;
161     case TfliteInferenceParams::COREML:
162       if (tool_params.HasParam("use_coreml")) {
163         tool_params.Set<bool>("use_coreml", true);
164       }
165       break;
166     default:
167       break;
168   }
169   return tool_params;
170 }
171 
172 }  // namespace evaluation
173 }  // namespace tflite
174