1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/api.h"
17
18 namespace tflite {
19 namespace gpu {
20 namespace {
21
22 struct ObjectTypeGetter {
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter23 ObjectType operator()(absl::monostate) const { return ObjectType::UNKNOWN; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter24 ObjectType operator()(OpenGlBuffer) const { return ObjectType::OPENGL_SSBO; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter25 ObjectType operator()(OpenGlTexture) const {
26 return ObjectType::OPENGL_TEXTURE;
27 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter28 ObjectType operator()(OpenClBuffer) const {
29 return ObjectType::OPENCL_BUFFER;
30 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter31 ObjectType operator()(OpenClTexture) const {
32 return ObjectType::OPENCL_TEXTURE;
33 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter34 ObjectType operator()(VulkanBuffer) const {
35 return ObjectType::VULKAN_BUFFER;
36 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter37 ObjectType operator()(VulkanTexture) const {
38 return ObjectType::VULKAN_TEXTURE;
39 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectTypeGetter40 ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
41 };
42
43 struct ObjectValidityChecker {
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker44 bool operator()(absl::monostate) const { return false; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker45 bool operator()(OpenGlBuffer obj) const { return obj.id != GL_INVALID_INDEX; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker46 bool operator()(OpenGlTexture obj) const {
47 return obj.id != GL_INVALID_INDEX && obj.format != GL_INVALID_ENUM;
48 }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker49 bool operator()(OpenClBuffer obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker50 bool operator()(OpenClTexture obj) const { return obj.memobj; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker51 bool operator()(VulkanBuffer obj) const { return obj.memory; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker52 bool operator()(VulkanTexture obj) const { return obj.memory; }
operator ()tflite::gpu::__anon4c7694df0111::ObjectValidityChecker53 bool operator()(CpuMemory obj) const {
54 return obj.data != nullptr && obj.size_bytes > 0 &&
55 (data_type == DataType::UNKNOWN ||
56 obj.size_bytes % SizeOf(data_type) == 0);
57 }
58 DataType data_type;
59 };
60
61 } // namespace
62
IsValid(const ObjectDef & def)63 bool IsValid(const ObjectDef& def) {
64 return def.data_type != DataType::UNKNOWN &&
65 def.data_layout != DataLayout::UNKNOWN &&
66 def.object_type != ObjectType::UNKNOWN;
67 }
68
GetType(const TensorObject & object)69 ObjectType GetType(const TensorObject& object) {
70 return absl::visit(ObjectTypeGetter{}, object);
71 }
72
IsValid(const TensorObjectDef & def)73 bool IsValid(const TensorObjectDef& def) { return IsValid(def.object_def); }
74
IsValid(const TensorObjectDef & def,const TensorObject & object)75 bool IsValid(const TensorObjectDef& def, const TensorObject& object) {
76 return GetType(object) == def.object_def.object_type &&
77 absl::visit(ObjectValidityChecker{def.object_def.data_type}, object);
78 }
79
IsObjectPresent(ObjectType type,const TensorObject & obj)80 bool IsObjectPresent(ObjectType type, const TensorObject& obj) {
81 switch (type) {
82 case ObjectType::CPU_MEMORY:
83 return absl::holds_alternative<CpuMemory>(obj);
84 case ObjectType::OPENGL_SSBO:
85 return absl::holds_alternative<OpenGlBuffer>(obj);
86 case ObjectType::OPENGL_TEXTURE:
87 return absl::holds_alternative<OpenGlTexture>(obj);
88 case ObjectType::OPENCL_BUFFER:
89 return absl::holds_alternative<OpenClBuffer>(obj);
90 case ObjectType::OPENCL_TEXTURE:
91 return absl::holds_alternative<OpenClTexture>(obj);
92 case ObjectType::VULKAN_BUFFER:
93 return absl::holds_alternative<VulkanBuffer>(obj);
94 case ObjectType::VULKAN_TEXTURE:
95 return absl::holds_alternative<VulkanTexture>(obj);
96 case ObjectType::UNKNOWN:
97 return false;
98 }
99 }
100
IsObjectInitialized(const TensorObject & obj)101 bool IsObjectInitialized(const TensorObject& obj) {
102 return GetType(obj) != ObjectType::UNKNOWN;
103 }
104
NumElements(const TensorObjectDef & def)105 uint32_t NumElements(const TensorObjectDef& def) {
106 const auto& d = def.dimensions;
107 switch (def.object_def.data_layout) {
108 case DataLayout::BHWC:
109 return d.product();
110 case DataLayout::HWDC4:
111 case DataLayout::HDWC4:
112 case DataLayout::DHWC4:
113 return d.b * d.h * d.w * AlignByN(d.c, 4);
114 case DataLayout::UNKNOWN:
115 return 0;
116 }
117 return 0;
118 }
119
GetPosition(const InferenceOptions & options,InferencePriority p)120 int GetPosition(const InferenceOptions& options, InferencePriority p) {
121 if (options.priority1 == p) return 1;
122 if (options.priority2 == p) return 2;
123 if (options.priority3 == p) return 3;
124 return 4; // least important
125 }
126
GetRelativeImportance(const InferenceOptions & options,InferencePriority p1,InferencePriority p2)127 PriorityImportance GetRelativeImportance(const InferenceOptions& options,
128 InferencePriority p1,
129 InferencePriority p2) {
130 int p1_position = GetPosition(options, p1);
131 int p2_position = GetPosition(options, p2);
132 if (p1_position == p2_position) return PriorityImportance::UNKNOWN;
133 return p1_position < p2_position ? PriorityImportance::HIGHER
134 : PriorityImportance::LOWER;
135 }
136
IsValid(const InferenceOptions & options)137 bool IsValid(const InferenceOptions& options) {
138 if (options.usage == InferenceUsage::UNKNOWN) {
139 return false;
140 }
141 if (options.priority1 == InferencePriority::UNKNOWN ||
142 options.priority2 == InferencePriority::UNKNOWN ||
143 options.priority3 == InferencePriority::UNKNOWN) {
144 return false;
145 }
146 if (options.priority1 == InferencePriority::AUTO) {
147 return false;
148 }
149 if (options.priority2 == InferencePriority::AUTO &&
150 options.priority3 != InferencePriority::AUTO) {
151 return false;
152 }
153 if (options.priority1 == options.priority2 ||
154 options.priority1 == options.priority3) {
155 return false;
156 }
157 if (options.priority2 == options.priority3 &&
158 options.priority2 != InferencePriority::AUTO) {
159 return false;
160 }
161 return true;
162 }
163
164 // Implementation note: this resolution logic is shared between GL and CL
165 // backends, but they might have own logic. Thus, the function is defined
166 // here just for code re-use purposes.
ResolveAutoPriority(InferenceOptions * options)167 void ResolveAutoPriority(InferenceOptions* options) {
168 // priority1 can not be AUTO as it would make options invalid.
169 if (options->priority2 == InferencePriority::AUTO) {
170 switch (options->priority1) {
171 case InferencePriority::MIN_LATENCY:
172 options->priority2 = InferencePriority::MIN_MEMORY_USAGE;
173 options->priority3 = InferencePriority::MAX_PRECISION;
174 return;
175 case InferencePriority::MIN_MEMORY_USAGE:
176 options->priority2 = InferencePriority::MAX_PRECISION;
177 options->priority3 = InferencePriority::MIN_LATENCY;
178 return;
179 case InferencePriority::MAX_PRECISION:
180 options->priority2 = InferencePriority::MIN_LATENCY;
181 options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
182 return;
183 case InferencePriority::UNKNOWN:
184 case InferencePriority::AUTO:
185 // Invalid and unreachable option.
186 return;
187 }
188 }
189
190 if (options->priority3 == InferencePriority::AUTO) {
191 // Simply add missing priority
192 if (GetPosition(*options, InferencePriority::MIN_LATENCY) == 4) {
193 options->priority3 = InferencePriority::MIN_LATENCY;
194 } else if (GetPosition(*options, InferencePriority::MAX_PRECISION) == 4) {
195 options->priority3 = InferencePriority::MAX_PRECISION;
196 } else if (GetPosition(*options, InferencePriority::MIN_MEMORY_USAGE) ==
197 4) {
198 options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
199 }
200 }
201 }
202
203 } // namespace gpu
204 } // namespace tflite
205