• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/api.h"
17 
18 namespace tflite {
19 namespace gpu {
20 namespace {
21 
22 struct ObjectTypeGetter {
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter23   ObjectType operator()(absl::monostate) const { return ObjectType::UNKNOWN; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter24   ObjectType operator()(OpenGlBuffer) const { return ObjectType::OPENGL_SSBO; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter25   ObjectType operator()(OpenGlTexture) const {
26     return ObjectType::OPENGL_TEXTURE;
27   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter28   ObjectType operator()(OpenClBuffer) const {
29     return ObjectType::OPENCL_BUFFER;
30   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter31   ObjectType operator()(OpenClTexture) const {
32     return ObjectType::OPENCL_TEXTURE;
33   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter34   ObjectType operator()(VulkanBuffer) const {
35     return ObjectType::VULKAN_BUFFER;
36   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter37   ObjectType operator()(VulkanTexture) const {
38     return ObjectType::VULKAN_TEXTURE;
39   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter40   ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
41 };
42 
43 struct ObjectValidityChecker {
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker44   bool operator()(absl::monostate) const { return false; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker45   bool operator()(OpenGlBuffer obj) const { return obj.id != GL_INVALID_INDEX; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker46   bool operator()(OpenGlTexture obj) const {
47     return obj.id != GL_INVALID_INDEX && obj.format != GL_INVALID_ENUM;
48   }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker49   bool operator()(OpenClBuffer obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker50   bool operator()(OpenClTexture obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker51   bool operator()(VulkanBuffer obj) const { return obj.memory; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker52   bool operator()(VulkanTexture obj) const { return obj.memory; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker53   bool operator()(CpuMemory obj) const {
54     return obj.data != nullptr && obj.size_bytes > 0 &&
55            (data_type == DataType::UNKNOWN ||
56             obj.size_bytes % SizeOf(data_type) == 0);
57   }
58   DataType data_type;
59 };
60 
61 }  // namespace
62 
IsValid(const ObjectDef & def)63 bool IsValid(const ObjectDef& def) {
64   return def.data_type != DataType::UNKNOWN &&
65          def.data_layout != DataLayout::UNKNOWN &&
66          def.object_type != ObjectType::UNKNOWN;
67 }
68 
GetType(const TensorObject & object)69 ObjectType GetType(const TensorObject& object) {
70   return absl::visit(ObjectTypeGetter{}, object);
71 }
72 
IsValid(const TensorObjectDef & def)73 bool IsValid(const TensorObjectDef& def) { return IsValid(def.object_def); }
74 
IsValid(const TensorObjectDef & def,const TensorObject & object)75 bool IsValid(const TensorObjectDef& def, const TensorObject& object) {
76   return GetType(object) == def.object_def.object_type &&
77          absl::visit(ObjectValidityChecker{def.object_def.data_type}, object);
78 }
79 
IsObjectPresent(ObjectType type,const TensorObject & obj)80 bool IsObjectPresent(ObjectType type, const TensorObject& obj) {
81   switch (type) {
82     case ObjectType::CPU_MEMORY:
83       return absl::holds_alternative<CpuMemory>(obj);
84     case ObjectType::OPENGL_SSBO:
85       return absl::holds_alternative<OpenGlBuffer>(obj);
86     case ObjectType::OPENGL_TEXTURE:
87       return absl::holds_alternative<OpenGlTexture>(obj);
88     case ObjectType::OPENCL_BUFFER:
89       return absl::holds_alternative<OpenClBuffer>(obj);
90     case ObjectType::OPENCL_TEXTURE:
91       return absl::holds_alternative<OpenClTexture>(obj);
92     case ObjectType::VULKAN_BUFFER:
93       return absl::holds_alternative<VulkanBuffer>(obj);
94     case ObjectType::VULKAN_TEXTURE:
95       return absl::holds_alternative<VulkanTexture>(obj);
96     case ObjectType::UNKNOWN:
97       return false;
98   }
99 }
100 
IsObjectInitialized(const TensorObject & obj)101 bool IsObjectInitialized(const TensorObject& obj) {
102   return GetType(obj) != ObjectType::UNKNOWN;
103 }
104 
NumElements(const TensorObjectDef & def)105 uint32_t NumElements(const TensorObjectDef& def) {
106   const auto& d = def.dimensions;
107   switch (def.object_def.data_layout) {
108     case DataLayout::BHWC:
109       return d.product();
110     case DataLayout::HWDC4:
111     case DataLayout::HDWC4:
112     case DataLayout::DHWC4:
113       return d.b * d.h * d.w * AlignByN(d.c, 4);
114     case DataLayout::UNKNOWN:
115       return 0;
116   }
117   return 0;
118 }
119 
GetPosition(const InferenceOptions & options,InferencePriority p)120 int GetPosition(const InferenceOptions& options, InferencePriority p) {
121   if (options.priority1 == p) return 1;
122   if (options.priority2 == p) return 2;
123   if (options.priority3 == p) return 3;
124   return 4;  // least important
125 }
126 
GetRelativeImportance(const InferenceOptions & options,InferencePriority p1,InferencePriority p2)127 PriorityImportance GetRelativeImportance(const InferenceOptions& options,
128                                          InferencePriority p1,
129                                          InferencePriority p2) {
130   int p1_position = GetPosition(options, p1);
131   int p2_position = GetPosition(options, p2);
132   if (p1_position == p2_position) return PriorityImportance::UNKNOWN;
133   return p1_position < p2_position ? PriorityImportance::HIGHER
134                                    : PriorityImportance::LOWER;
135 }
136 
IsValid(const InferenceOptions & options)137 bool IsValid(const InferenceOptions& options) {
138   if (options.usage == InferenceUsage::UNKNOWN) {
139     return false;
140   }
141   if (options.priority1 == InferencePriority::UNKNOWN ||
142       options.priority2 == InferencePriority::UNKNOWN ||
143       options.priority3 == InferencePriority::UNKNOWN) {
144     return false;
145   }
146   if (options.priority1 == InferencePriority::AUTO) {
147     return false;
148   }
149   if (options.priority2 == InferencePriority::AUTO &&
150       options.priority3 != InferencePriority::AUTO) {
151     return false;
152   }
153   if (options.priority1 == options.priority2 ||
154       options.priority1 == options.priority3) {
155     return false;
156   }
157   if (options.priority2 == options.priority3 &&
158       options.priority2 != InferencePriority::AUTO) {
159     return false;
160   }
161   return true;
162 }
163 
164 // Implementation note: this resolution logic is shared between GL and CL
165 // backends, but they might have own logic. Thus, the function is defined
166 // here just for code re-use purposes.
ResolveAutoPriority(InferenceOptions * options)167 void ResolveAutoPriority(InferenceOptions* options) {
168   // priority1 can not be AUTO as it would make options invalid.
169   if (options->priority2 == InferencePriority::AUTO) {
170     switch (options->priority1) {
171       case InferencePriority::MIN_LATENCY:
172         options->priority2 = InferencePriority::MIN_MEMORY_USAGE;
173         options->priority3 = InferencePriority::MAX_PRECISION;
174         return;
175       case InferencePriority::MIN_MEMORY_USAGE:
176         options->priority2 = InferencePriority::MAX_PRECISION;
177         options->priority3 = InferencePriority::MIN_LATENCY;
178         return;
179       case InferencePriority::MAX_PRECISION:
180         options->priority2 = InferencePriority::MIN_LATENCY;
181         options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
182         return;
183       case InferencePriority::UNKNOWN:
184       case InferencePriority::AUTO:
185         // Invalid and unreachable option.
186         return;
187     }
188   }
189 
190   if (options->priority3 == InferencePriority::AUTO) {
191     // Simply add missing priority
192     if (GetPosition(*options, InferencePriority::MIN_LATENCY) == 4) {
193       options->priority3 = InferencePriority::MIN_LATENCY;
194     } else if (GetPosition(*options, InferencePriority::MAX_PRECISION) == 4) {
195       options->priority3 = InferencePriority::MAX_PRECISION;
196     } else if (GetPosition(*options, InferencePriority::MIN_MEMORY_USAGE) ==
197                4) {
198       options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
199     }
200   }
201 }
202 
203 }  // namespace gpu
204 }  // namespace tflite
205