1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/cl_arguments.h"
17
18 #include <string>
19
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
25 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
26 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
27 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
28 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
29 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
30 #include "tensorflow/lite/delegates/gpu/common/util.h"
31
32 namespace tflite {
33 namespace gpu {
34 namespace cl {
35 namespace {
IsWordSymbol(char symbol)36 bool IsWordSymbol(char symbol) {
37 return absl::ascii_isalnum(symbol) || symbol == '_';
38 }
39
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)40 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
41 std::string* str) {
42 size_t position = str->find(old_word);
43 while (position != std::string::npos) {
44 char prev = position == 0 ? '.' : (*str)[position - 1];
45 char next = position + old_word.size() < str->size()
46 ? (*str)[position + old_word.size()]
47 : '.';
48 if (IsWordSymbol(prev) || IsWordSymbol(next)) {
49 position = str->find(old_word, position + 1);
50 continue;
51 }
52 str->replace(position, old_word.size(), new_word);
53 position = str->find(old_word, position + new_word.size());
54 }
55 }
56
GetNextWord(const std::string & code,size_t first_position)57 std::string GetNextWord(const std::string& code, size_t first_position) {
58 size_t pos = first_position;
59 char t = code[pos];
60 while (IsWordSymbol(t)) {
61 pos++;
62 t = code[pos];
63 }
64 return code.substr(first_position, pos - first_position);
65 }
66
// Finds the bracket that closes an already-opened `bracket`.
//
// `bracket` is the opening character ('(', '{', '[' or '<') and `first_pos` is
// the index right after that opening bracket. Nested pairs of the same kind
// are skipped. Returns the index one past the matching closing bracket, or
// size_t(-1) when `bracket` is not a supported opening bracket or no matching
// closing bracket exists.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  char closing = '\0';
  switch (bracket) {
    case '(':
      closing = ')';
      break;
    case '{':
      closing = '}';
      break;
    case '[':
      closing = ']';
      break;
    case '<':
      closing = '>';
      break;
    default:
      return std::string::npos;
  }
  // One bracket is already open when the scan starts.
  int depth = 1;
  for (size_t cursor = first_pos; cursor < text.size(); ++cursor) {
    const char symbol = text[cursor];
    if (symbol == bracket) {
      ++depth;
      continue;
    }
    if (symbol != closing) {
      continue;
    }
    if (--depth == 0) {
      return cursor + 1;
    }
  }
  return std::string::npos;
}
98
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)99 absl::Status ParseArgsInsideBrackets(const std::string& text,
100 size_t open_bracket_pos,
101 size_t* close_bracket_pos,
102 std::vector<std::string>* args) {
103 *close_bracket_pos =
104 FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
105 if (*close_bracket_pos == -1) {
106 return absl::NotFoundError("Not found enclosing bracket");
107 }
108 std::string str_args = text.substr(open_bracket_pos + 1,
109 *close_bracket_pos - open_bracket_pos - 2);
110 std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
111 args->reserve(words.size());
112 for (const auto& word : words) {
113 absl::string_view arg = absl::StripAsciiWhitespace(word);
114 if (!arg.empty()) {
115 args->push_back(std::string(arg));
116 }
117 }
118 return absl::OkStatus();
119 }
120
AppendArgument(const std::string & arg,std::string * args)121 void AppendArgument(const std::string& arg, std::string* args) {
122 if (!args->empty()) {
123 absl::StrAppend(args, ",\n ");
124 }
125 absl::StrAppend(args, arg);
126 }
127
GetImageModifier(AccessType access)128 std::string GetImageModifier(AccessType access) {
129 switch (access) {
130 case AccessType::READ:
131 return "__read_only";
132 case AccessType::WRITE:
133 return "__write_only";
134 case AccessType::READ_WRITE:
135 return "__read_write";
136 }
137 }
138
GetDefaultSamplers(const GpuInfo & gpu_info)139 std::string GetDefaultSamplers(const GpuInfo& gpu_info) {
140 std::string result;
141 result +=
142 "__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | "
143 "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
144 if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
145 // Unfortunately, CLK_ADDRESS_CLAMP is very slow on Adreno3xx and
146 // we can observe huge register overhead when compared to other modes.
147
148 // While using CLK_ADDRESS_NONE with out-of-range image coordinates is
149 // undefined in the OpenCL specification, we have observed that
150 // CLK_ADDRESS_NONE works like CLK_ADDRESS_CLAMP for out-of-range image
151 // coordinates for RGBA F16/F32 textures on Adreno3xx devices. Using
152 // CLK_ADDRESS_NONE is significantly faster than CLK_ADDRESS_CLAMP on Adreno
153 // 3xx.
154 result +=
155 "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
156 "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
157 } else {
158 result +=
159 "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
160 "CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n";
161 }
162
163 return result;
164 }
165
CreateCLObject(GPUObjectDescriptor * desc,CLContext * context,GPUObjectPtr * result)166 absl::Status CreateCLObject(GPUObjectDescriptor* desc, CLContext* context,
167 GPUObjectPtr* result) {
168 const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc);
169 if (buffer_desc) {
170 Buffer gpu_buffer;
171 RETURN_IF_ERROR(
172 gpu_buffer.CreateFromBufferDescriptor(*buffer_desc, context));
173 *result = absl::make_unique<Buffer>(std::move(gpu_buffer));
174 return absl::OkStatus();
175 }
176
177 const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
178 if (texture_desc) {
179 Texture2D gpu_texture;
180 RETURN_IF_ERROR(
181 gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, context));
182 *result = absl::make_unique<Texture2D>(std::move(gpu_texture));
183 return absl::OkStatus();
184 }
185
186 const auto* linear_desc = dynamic_cast<const TensorLinearDescriptor*>(desc);
187 if (linear_desc) {
188 LinearStorage gpu_storage;
189 RETURN_IF_ERROR(
190 gpu_storage.CreateFromTensorLinearDescriptor(*linear_desc, context));
191 *result = absl::make_unique<LinearStorage>(std::move(gpu_storage));
192 return absl::OkStatus();
193 }
194
195 const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
196 if (tensor_desc) {
197 Tensor gpu_tensor;
198 RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, context));
199 *result = absl::make_unique<Tensor>(std::move(gpu_tensor));
200 return absl::OkStatus();
201 }
202
203 return absl::InvalidArgumentError("Unknown GPU descriptor.");
204 }
205
206 } // namespace
207
// Static
// Out-of-class definition of the static constexpr member. Before C++17 this
// definition is required for the member to be odr-used (for example when it is
// bound to a reference parameter).
constexpr char CLArguments::kArgsPrefix[];
210
// Initializes the kernel arguments from `args` and rewrites `*code` into its
// final form.
//
// The steps below are order-dependent. Returns the first error reported by
// object allocation, selector resolution or resource binding.
absl::Status CLArguments::Init(
    const GpuInfo& gpu_info,
    const std::map<std::string, std::string>& linkables, CLContext* context,
    Arguments* args, std::string* code) {
  // Create a CL object for every descriptor owned by `args`.
  RETURN_IF_ERROR(AllocateObjects(*args, context));
  // Register the resources of owned objects and object refs as arguments.
  RETURN_IF_ERROR(AddObjectArgs(args));
  // Selector resolution looks objects up in args->object_refs_, so it has to
  // run before the refs are moved out of `args` on the next line.
  RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, *args, linkables, code));
  object_refs_ = std::move(args->object_refs_);
  // NOTE(review): presumably flags which scalar arguments `*code` actually
  // uses -- confirm against Arguments::GetActiveArguments.
  args->GetActiveArguments(kArgsPrefix, *code);
  // On PowerVR, active half scalars are stored in the shared float data.
  const bool use_f32_for_halfs = gpu_info.IsPowerVR();
  CopyArguments(*args, use_f32_for_halfs);
  RETURN_IF_ERROR(SetObjectsResources(*args));
  // Point scalar references at their shared_*4 slots, then drop the argument
  // prefix from whatever prefixed names remain.
  RenameArgumentsInCode(code);
  ResolveArgsPass(code);
  // Fill the $0 placeholder of the code with the generated argument list.
  *code = absl::Substitute(*code, GetListOfArgs());
  if (gpu_info.SupportsImages()) {
    *code = GetDefaultSamplers(gpu_info) + *code;
  }
  return absl::OkStatus();
}
231
// Initializes the kernel arguments from `args` without touching any kernel
// source: allocates the CL objects, registers their resources, takes over the
// object refs, copies the scalar values and binds the object resources.
absl::Status CLArguments::Init(const GpuInfo& gpu_info, Arguments* args,
                               CLContext* context) {
  RETURN_IF_ERROR(AllocateObjects(*args, context));
  // Must run before the move below: it also walks args->object_refs_.
  RETURN_IF_ERROR(AddObjectArgs(args));
  object_refs_ = std::move(args->object_refs_);
  // On PowerVR, active half scalars are stored in the shared float data.
  const bool use_f32_for_halfs = gpu_info.IsPowerVR();
  CopyArguments(*args, use_f32_for_halfs);
  RETURN_IF_ERROR(SetObjectsResources(*args));
  return absl::OkStatus();
}
242
AllocateObjects(const Arguments & args,CLContext * context)243 absl::Status CLArguments::AllocateObjects(const Arguments& args,
244 CLContext* context) {
245 objects_.resize(args.objects_.size());
246 int i = 0;
247 for (auto& t : args.objects_) {
248 RETURN_IF_ERROR(CreateCLObject(t.second.get(), context, &objects_[i]));
249 i++;
250 }
251 return absl::OkStatus();
252 }
253
AddObjectArgs(Arguments * args)254 absl::Status CLArguments::AddObjectArgs(Arguments* args) {
255 for (auto& t : args->objects_) {
256 AddGPUResources(t.first, t.second->GetGPUResources(), args);
257 }
258 for (auto& t : args->object_refs_) {
259 AddGPUResources(t.first, t.second->GetGPUResources(), args);
260 }
261 return absl::OkStatus();
262 }
263
SetObjectsResources(const Arguments & args)264 absl::Status CLArguments::SetObjectsResources(const Arguments& args) {
265 int i = 0;
266 for (const auto& t : args.objects_) {
267 GPUResourcesWithValue resources;
268 RETURN_IF_ERROR(objects_[i]->GetGPUResources(t.second.get(), &resources));
269 RETURN_IF_ERROR(SetGPUResources(t.first, resources));
270 i++;
271 }
272 return absl::OkStatus();
273 }
274
ResolveSelectorsPass(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,std::string * code)275 absl::Status CLArguments::ResolveSelectorsPass(
276 const GpuInfo& gpu_info, const Arguments& args,
277 const std::map<std::string, std::string>& linkables, std::string* code) {
278 std::string result;
279 size_t position = 0;
280 size_t next_position = code->find(kArgsPrefix);
281 while (next_position != std::string::npos) {
282 size_t arg_pos = next_position;
283 next_position += strlen(kArgsPrefix);
284 std::string object_name = GetNextWord(*code, next_position);
285 char next = (*code)[next_position + object_name.size()];
286 if (next == '.') {
287 next_position += object_name.size() + 1;
288 std::string selector_name = GetNextWord(*code, next_position);
289 next_position += selector_name.size();
290 next = (*code)[next_position];
291 std::vector<std::string> template_args;
292 if (next == '<') {
293 size_t close_bracket_pos;
294 RETURN_IF_ERROR(ParseArgsInsideBrackets(
295 *code, next_position, &close_bracket_pos, &template_args));
296 next_position = close_bracket_pos;
297 next = (*code)[next_position];
298 }
299 if (next != '(') {
300 return absl::NotFoundError(absl::StrCat(
301 "Expected ( after ", object_name, ".", selector_name, " call"));
302 }
303 std::vector<std::string> function_args;
304 size_t close_bracket_pos;
305 RETURN_IF_ERROR(ParseArgsInsideBrackets(
306 *code, next_position, &close_bracket_pos, &function_args));
307 for (auto& arg : function_args) {
308 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
309 }
310 std::string patch;
311 RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
312 selector_name, function_args,
313 template_args, &patch));
314 code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
315 position = arg_pos + patch.size();
316 } else {
317 position = arg_pos + strlen(kArgsPrefix);
318 }
319 next_position = code->find(kArgsPrefix, position);
320 }
321 return absl::OkStatus();
322 }
323
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code)324 void CLArguments::ResolveObjectNames(
325 const std::string& object_name,
326 const std::vector<std::string>& member_names, std::string* code) {
327 for (const auto& member_name : member_names) {
328 const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
329 ReplaceAllWords(member_name, new_name, code);
330 }
331 }
332
ResolveSelector(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,const std::string & object_name,const std::string & selector,const std::vector<std::string> & function_args,const std::vector<std::string> & template_args,std::string * result)333 absl::Status CLArguments::ResolveSelector(
334 const GpuInfo& gpu_info, const Arguments& args,
335 const std::map<std::string, std::string>& linkables,
336 const std::string& object_name, const std::string& selector,
337 const std::vector<std::string>& function_args,
338 const std::vector<std::string>& template_args, std::string* result) {
339 const GPUObjectDescriptor* desc_ptr;
340 auto it_ref = args.object_refs_.find(object_name);
341 auto it_obj = args.objects_.find(object_name);
342 if (it_ref != args.object_refs_.end()) {
343 desc_ptr = it_ref->second.get();
344 } else if (it_obj != args.objects_.end()) {
345 desc_ptr = it_obj->second.get();
346 } else {
347 return absl::NotFoundError(
348 absl::StrCat("No object with name - ", object_name));
349 }
350 auto names = desc_ptr->GetGPUResources().GetNames();
351 const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
352 if (tensor_desc && (selector == "Write" || selector == "Linking")) {
353 auto it = linkables.find(object_name);
354 if (it != linkables.end()) {
355 if (desc_ptr->GetAccess() != AccessType::WRITE &&
356 desc_ptr->GetAccess() != AccessType::READ_WRITE) {
357 return absl::FailedPreconditionError(absl::StrCat(
358 "Object with name - ", object_name, " should have Write access."));
359 }
360 std::string value_name, x_coord, y_coord, s_coord;
361 RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
362 function_args, &value_name, &x_coord, &y_coord, &s_coord));
363 // x_coord can have batch size property of link_object
364 ResolveObjectNames(object_name, names, &x_coord);
365 *result = it->second;
366 ReplaceAllWords("in_out_value", value_name, result);
367 ReplaceAllWords("X_COORD", x_coord, result);
368 ReplaceAllWords("Y_COORD", y_coord, result);
369 ReplaceAllWords("S_COORD", s_coord, result);
370 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
371 if (selector == "Linking") {
372 return absl::OkStatus();
373 }
374 }
375 }
376 std::string patch;
377 RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
378 template_args, &patch));
379 ResolveObjectNames(object_name, names, &patch);
380 *result += patch;
381 return absl::OkStatus();
382 }
383
ResolveArgsPass(std::string * code)384 void CLArguments::ResolveArgsPass(std::string* code) {
385 size_t position = 0;
386 size_t next_position = code->find(kArgsPrefix);
387 while (next_position != std::string::npos) {
388 size_t arg_pos = next_position;
389 next_position += strlen(kArgsPrefix);
390 std::string object_name = GetNextWord(*code, next_position);
391 std::string new_name = object_name;
392 code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
393 position = arg_pos + new_name.size();
394 next_position = code->find(kArgsPrefix, position);
395 }
396 }
397
CopyScalarValues(Arguments * args) const398 void CLArguments::CopyScalarValues(Arguments* args) const {
399 for (const auto& fvalue : float_values_) {
400 args->float_values_[fvalue.first].value = fvalue.second.value;
401 }
402 for (const auto& ivalue : int_values_) {
403 args->int_values_[ivalue.first].value = ivalue.second.value;
404 }
405 for (const auto& hfvalue : half_values_) {
406 args->half_values_[hfvalue.first].value = hfvalue.second.value;
407 }
408 }
409
CopyArguments(const Arguments & args,bool use_f32_for_halfs)410 void CLArguments::CopyArguments(const Arguments& args, bool use_f32_for_halfs) {
411 for (const auto& fvalue : args.float_values_) {
412 auto& new_val = float_values_[fvalue.first];
413 new_val.value = fvalue.second.value;
414 new_val.active = fvalue.second.active;
415 if (fvalue.second.active) {
416 new_val.offset = shared_float4s_data_.size();
417 shared_float4s_data_.push_back(new_val.value);
418 }
419 }
420 for (const auto& ivalue : args.int_values_) {
421 auto& new_val = int_values_[ivalue.first];
422 new_val.value = ivalue.second.value;
423 new_val.active = ivalue.second.active;
424 if (ivalue.second.active) {
425 new_val.offset = shared_int4s_data_.size();
426 shared_int4s_data_.push_back(new_val.value);
427 }
428 }
429 for (const auto& hfvalue : args.half_values_) {
430 auto& new_val = half_values_[hfvalue.first];
431 new_val.value = hfvalue.second.value;
432 new_val.active = hfvalue.second.active;
433 if (hfvalue.second.active) {
434 if (use_f32_for_halfs) {
435 new_val.store_as_f32 = true;
436 new_val.offset = shared_float4s_data_.size();
437 shared_float4s_data_.push_back(new_val.value);
438 } else {
439 new_val.store_as_f32 = false;
440 new_val.offset = shared_half4s_data_.size();
441 shared_half4s_data_.push_back(new_val.value);
442 }
443 }
444 }
445 int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
446 shared_int4s_data_.resize(shared_int4s_aligned_size);
447 int shared_float4s_aligned_size = AlignByN(shared_float4s_data_.size(), 4);
448 shared_float4s_data_.resize(shared_float4s_aligned_size);
449 int shared_half4s_aligned_size = AlignByN(shared_half4s_data_.size(), 4);
450 shared_half4s_data_.resize(shared_half4s_aligned_size);
451 }
452
RenameArgumentsInCode(std::string * code)453 void CLArguments::RenameArgumentsInCode(std::string* code) {
454 const std::string postfixes[4] = {"x", "y", "z", "w"};
455 for (const auto& fvalue : float_values_) {
456 if (fvalue.second.active) {
457 std::string index = std::to_string(fvalue.second.offset / 4);
458 std::string new_name =
459 "shared_float4_" + index + "." + postfixes[fvalue.second.offset % 4];
460 ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
461 }
462 }
463 for (const auto& ivalue : int_values_) {
464 if (ivalue.second.active) {
465 std::string index = std::to_string(ivalue.second.offset / 4);
466 std::string new_name =
467 "shared_int4_" + index + "." + postfixes[ivalue.second.offset % 4];
468 ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
469 }
470 }
471 for (const auto& hfvalue : half_values_) {
472 if (hfvalue.second.active) {
473 std::string index = std::to_string(hfvalue.second.offset / 4);
474 std::string new_name;
475 if (hfvalue.second.store_as_f32) {
476 new_name = "(half)(shared_float4_" + index + "." +
477 postfixes[hfvalue.second.offset % 4] + ")";
478 } else {
479 new_name = "shared_half4_" + index + "." +
480 postfixes[hfvalue.second.offset % 4];
481 }
482 ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
483 }
484 }
485 }
486
AddBuffer(const std::string & name,const GPUBufferDescriptor & desc)487 void CLArguments::AddBuffer(const std::string& name,
488 const GPUBufferDescriptor& desc) {
489 buffers_[name].desc = desc;
490 }
AddImage2D(const std::string & name,const GPUImage2DDescriptor & desc)491 void CLArguments::AddImage2D(const std::string& name,
492 const GPUImage2DDescriptor& desc) {
493 images2d_[name].desc = desc;
494 }
495
AddImage2DArray(const std::string & name,const GPUImage2DArrayDescriptor & desc)496 void CLArguments::AddImage2DArray(const std::string& name,
497 const GPUImage2DArrayDescriptor& desc) {
498 image2d_arrays_[name].desc = desc;
499 }
500
AddImage3D(const std::string & name,const GPUImage3DDescriptor & desc)501 void CLArguments::AddImage3D(const std::string& name,
502 const GPUImage3DDescriptor& desc) {
503 images3d_[name].desc = desc;
504 }
505
AddImageBuffer(const std::string & name,const GPUImageBufferDescriptor & desc)506 void CLArguments::AddImageBuffer(const std::string& name,
507 const GPUImageBufferDescriptor& desc) {
508 image_buffers_[name].desc = desc;
509 }
510
AddCustomMemory(const std::string & name,const GPUCustomMemoryDescriptor & desc)511 void CLArguments::AddCustomMemory(const std::string& name,
512 const GPUCustomMemoryDescriptor& desc) {
513 custom_memories_[name].desc = desc;
514 }
515
AddGPUResources(const std::string & name,const GPUResources & resources,Arguments * args)516 void CLArguments::AddGPUResources(const std::string& name,
517 const GPUResources& resources,
518 Arguments* args) {
519 for (const auto& r : resources.ints) {
520 args->AddInt(absl::StrCat(name, "_", r));
521 }
522 for (const auto& r : resources.floats) {
523 args->AddFloat(absl::StrCat(name, "_", r));
524 }
525 for (const auto& r : resources.buffers) {
526 AddBuffer(absl::StrCat(name, "_", r.first), r.second);
527 }
528 for (const auto& r : resources.images2d) {
529 AddImage2D(absl::StrCat(name, "_", r.first), r.second);
530 }
531 for (const auto& r : resources.image2d_arrays) {
532 AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
533 }
534 for (const auto& r : resources.images3d) {
535 AddImage3D(absl::StrCat(name, "_", r.first), r.second);
536 }
537 for (const auto& r : resources.image_buffers) {
538 AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
539 }
540 for (const auto& r : resources.custom_memories) {
541 AddCustomMemory(absl::StrCat(name, "_", r.first), r.second);
542 }
543 }
544
SetInt(const std::string & name,int value)545 absl::Status CLArguments::SetInt(const std::string& name, int value) {
546 auto it = int_values_.find(name);
547 if (it == int_values_.end()) {
548 return absl::NotFoundError(
549 absl::StrCat("No int argument with name - ", name));
550 }
551 it->second.value = value;
552 if (it->second.active) {
553 shared_int4s_data_[it->second.offset] = value;
554 }
555 return absl::OkStatus();
556 }
SetFloat(const std::string & name,float value)557 absl::Status CLArguments::SetFloat(const std::string& name, float value) {
558 auto it = float_values_.find(name);
559 if (it == float_values_.end()) {
560 return absl::NotFoundError(
561 absl::StrCat("No float argument with name - ", name));
562 }
563 it->second.value = value;
564 if (it->second.active) {
565 shared_float4s_data_[it->second.offset] = value;
566 }
567 return absl::OkStatus();
568 }
569
SetHalf(const std::string & name,half value)570 absl::Status CLArguments::SetHalf(const std::string& name, half value) {
571 auto it = half_values_.find(name);
572 if (it == half_values_.end()) {
573 return absl::NotFoundError(
574 absl::StrCat("No half argument with name - ", name));
575 }
576 it->second.value = value;
577 if (it->second.active) {
578 if (it->second.store_as_f32) {
579 shared_float4s_data_[it->second.offset] = value;
580 } else {
581 shared_half4s_data_[it->second.offset] = value;
582 }
583 }
584 return absl::OkStatus();
585 }
586
SetImage2D(const std::string & name,cl_mem memory)587 absl::Status CLArguments::SetImage2D(const std::string& name, cl_mem memory) {
588 auto it = images2d_.find(name);
589 if (it == images2d_.end()) {
590 return absl::NotFoundError(
591 absl::StrCat("No image2D argument with name - ", name));
592 }
593 it->second.memory = memory;
594 return absl::OkStatus();
595 }
596
SetBuffer(const std::string & name,cl_mem memory)597 absl::Status CLArguments::SetBuffer(const std::string& name, cl_mem memory) {
598 auto it = buffers_.find(name);
599 if (it == buffers_.end()) {
600 return absl::NotFoundError(
601 absl::StrCat("No buffer argument with name - ", name));
602 }
603 it->second.memory = memory;
604 return absl::OkStatus();
605 }
606
SetImage2DArray(const std::string & name,cl_mem memory)607 absl::Status CLArguments::SetImage2DArray(const std::string& name,
608 cl_mem memory) {
609 auto it = image2d_arrays_.find(name);
610 if (it == image2d_arrays_.end()) {
611 return absl::NotFoundError(
612 absl::StrCat("No image2D array argument with name - ", name));
613 }
614 it->second.memory = memory;
615 return absl::OkStatus();
616 }
617
SetImage3D(const std::string & name,cl_mem memory)618 absl::Status CLArguments::SetImage3D(const std::string& name, cl_mem memory) {
619 auto it = images3d_.find(name);
620 if (it == images3d_.end()) {
621 return absl::NotFoundError(
622 absl::StrCat("No image3D argument with name - ", name));
623 }
624 it->second.memory = memory;
625 return absl::OkStatus();
626 }
627
SetImageBuffer(const std::string & name,cl_mem memory)628 absl::Status CLArguments::SetImageBuffer(const std::string& name,
629 cl_mem memory) {
630 auto it = image_buffers_.find(name);
631 if (it == image_buffers_.end()) {
632 return absl::NotFoundError(
633 absl::StrCat("No image buffer argument with name - ", name));
634 }
635 it->second.memory = memory;
636 return absl::OkStatus();
637 }
638
SetCustomMemory(const std::string & name,cl_mem memory)639 absl::Status CLArguments::SetCustomMemory(const std::string& name,
640 cl_mem memory) {
641 auto it = custom_memories_.find(name);
642 if (it == custom_memories_.end()) {
643 return absl::NotFoundError(
644 absl::StrCat("No custom memory argument with name - ", name));
645 }
646 it->second.memory = memory;
647 return absl::OkStatus();
648 }
649
SetObjectRef(const std::string & name,const GPUObject * object)650 absl::Status CLArguments::SetObjectRef(const std::string& name,
651 const GPUObject* object) {
652 auto it = object_refs_.find(name);
653 if (it == object_refs_.end()) {
654 return absl::NotFoundError(
655 absl::StrCat("No object ref with name - ", name));
656 }
657 GPUResourcesWithValue resources;
658 RETURN_IF_ERROR(object->GetGPUResources(it->second.get(), &resources));
659 return SetGPUResources(name, resources);
660 }
661
SetGPUResources(const std::string & name,const GPUResourcesWithValue & resources)662 absl::Status CLArguments::SetGPUResources(
663 const std::string& name, const GPUResourcesWithValue& resources) {
664 for (const auto& r : resources.ints) {
665 RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
666 }
667 for (const auto& r : resources.floats) {
668 RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
669 }
670 for (const auto& r : resources.buffers) {
671 RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
672 }
673 for (const auto& r : resources.images2d) {
674 RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
675 }
676 for (const auto& r : resources.image2d_arrays) {
677 RETURN_IF_ERROR(
678 SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
679 }
680 for (const auto& r : resources.images3d) {
681 RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
682 }
683 for (const auto& r : resources.image_buffers) {
684 RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
685 }
686 for (const auto& r : resources.custom_memories) {
687 RETURN_IF_ERROR(
688 SetCustomMemory(absl::StrCat(name, "_", r.first), r.second));
689 }
690 return absl::OkStatus();
691 }
692
GetListOfArgs()693 std::string CLArguments::GetListOfArgs() {
694 std::string result;
695 for (auto& t : buffers_) {
696 const std::string type_name =
697 t.second.desc.data_type == DataType::FLOAT32 ? "float" : "half";
698 std::string attributes;
699 for (const auto& attr : t.second.desc.attributes) {
700 attributes += absl::StrCat(" __attribute__((", attr, "))");
701 }
702 AppendArgument(
703 absl::StrCat(
704 MemoryTypeToCLType(t.second.desc.memory_type), " ",
705 ToCLDataType(t.second.desc.data_type, t.second.desc.element_size),
706 "* ", t.first, attributes),
707 &result);
708 }
709 for (auto& t : image_buffers_) {
710 AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
711 " image1d_buffer_t ", t.first),
712 &result);
713 }
714 for (auto& t : images2d_) {
715 AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
716 " image2d_t ", t.first),
717 &result);
718 }
719 for (auto& t : image2d_arrays_) {
720 AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
721 " image2d_array_t ", t.first),
722 &result);
723 }
724 for (auto& t : images3d_) {
725 AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
726 " image3d_t ", t.first),
727 &result);
728 }
729 for (auto& t : custom_memories_) {
730 AppendArgument(absl::StrCat(t.second.desc.type_name, " ", t.first),
731 &result);
732 }
733 for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
734 AppendArgument(absl::StrCat("int4 shared_int4_", i), &result);
735 }
736 for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
737 AppendArgument(absl::StrCat("float4 shared_float4_", i), &result);
738 }
739 for (int i = 0; i < shared_half4s_data_.size() / 4; ++i) {
740 AppendArgument(absl::StrCat("half4 shared_half4_", i), &result);
741 }
742 return result;
743 }
744
Bind(cl_kernel kernel,int offset)745 absl::Status CLArguments::Bind(cl_kernel kernel, int offset) {
746 for (auto& t : buffers_) {
747 const int error_code =
748 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
749 if (error_code != CL_SUCCESS) {
750 return absl::UnknownError(absl::StrCat(
751 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
752 "(at index - ", offset, ")"));
753 }
754 offset++;
755 }
756 for (auto& t : image_buffers_) {
757 const int error_code =
758 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
759 if (error_code != CL_SUCCESS) {
760 return absl::UnknownError(absl::StrCat(
761 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
762 "(at index - ", offset, ")"));
763 }
764 offset++;
765 }
766 for (auto& t : images2d_) {
767 const int error_code =
768 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
769 if (error_code != CL_SUCCESS) {
770 return absl::UnknownError(absl::StrCat(
771 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
772 "(at index - ", offset, ")"));
773 }
774 offset++;
775 }
776 for (auto& t : image2d_arrays_) {
777 const int error_code =
778 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
779 if (error_code != CL_SUCCESS) {
780 return absl::UnknownError(absl::StrCat(
781 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
782 "(at index - ", offset, ")"));
783 }
784 offset++;
785 }
786 for (auto& t : images3d_) {
787 const int error_code =
788 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
789 if (error_code != CL_SUCCESS) {
790 return absl::UnknownError(absl::StrCat(
791 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
792 "(at index - ", offset, ")"));
793 }
794 offset++;
795 }
796 for (auto& t : custom_memories_) {
797 const int error_code =
798 clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
799 if (error_code != CL_SUCCESS) {
800 return absl::UnknownError(absl::StrCat(
801 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
802 "(at index - ", offset, ")"));
803 }
804 offset++;
805 }
806 for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
807 const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
808 &shared_int4s_data_[i * 4]);
809 if (error_code != CL_SUCCESS) {
810 return absl::UnknownError(absl::StrCat(
811 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
812 "(at index - ", offset, ")"));
813 }
814 offset++;
815 }
816 for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
817 const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
818 &shared_float4s_data_[i * 4]);
819 if (error_code != CL_SUCCESS) {
820 return absl::UnknownError(absl::StrCat(
821 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
822 "(at index - ", offset, ")"));
823 }
824 offset++;
825 }
826 for (int i = 0; i < shared_half4s_data_.size() / 4; ++i) {
827 const int error_code = clSetKernelArg(kernel, offset, sizeof(int16_t) * 4,
828 &shared_half4s_data_[i * 4]);
829 if (error_code != CL_SUCCESS) {
830 return absl::UnknownError(absl::StrCat(
831 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
832 "(at index - ", offset, ")"));
833 }
834 offset++;
835 }
836 return absl::OkStatus();
837 }
838
839 } // namespace cl
840 } // namespace gpu
841 } // namespace tflite
842