• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_arguments.h"
17 
18 #include <string>
19 
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
25 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
26 #include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
27 #include "tensorflow/lite/delegates/gpu/cl/tensor.h"
28 #include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
29 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
30 #include "tensorflow/lite/delegates/gpu/common/util.h"
31 
32 namespace tflite {
33 namespace gpu {
34 namespace cl {
35 namespace {
IsWordSymbol(char symbol)36 bool IsWordSymbol(char symbol) {
37   return absl::ascii_isalnum(symbol) || symbol == '_';
38 }
39 
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)40 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
41                      std::string* str) {
42   size_t position = str->find(old_word);
43   while (position != std::string::npos) {
44     char prev = position == 0 ? '.' : (*str)[position - 1];
45     char next = position + old_word.size() < str->size()
46                     ? (*str)[position + old_word.size()]
47                     : '.';
48     if (IsWordSymbol(prev) || IsWordSymbol(next)) {
49       position = str->find(old_word, position + 1);
50       continue;
51     }
52     str->replace(position, old_word.size(), new_word);
53     position = str->find(old_word, position + new_word.size());
54   }
55 }
56 
GetNextWord(const std::string & code,size_t first_position)57 std::string GetNextWord(const std::string& code, size_t first_position) {
58   size_t pos = first_position;
59   char t = code[pos];
60   while (IsWordSymbol(t)) {
61     pos++;
62     t = code[pos];
63   }
64   return code.substr(first_position, pos - first_position);
65 }
66 
// Finds the bracket that closes an already-opened `bracket`, scanning `text`
// from `first_pos`. Nested brackets of the same kind are honored.
//
// `bracket` must be one of '(', '{', '[' or '<'. The caller has already
// consumed the opening bracket, so scanning starts with one bracket open.
//
// Returns the position one past the matching closing bracket. Returns
// std::string::npos if `bracket` is not a supported opening bracket or if the
// text ends before the bracket is closed.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  // A switch avoids building a lookup map on every call.
  char b_close = '\0';
  switch (bracket) {
    case '(':
      b_close = ')';
      break;
    case '{':
      b_close = '}';
      break;
    case '[':
      b_close = ']';
      break;
    case '<':
      b_close = '>';
      break;
    default:
      return std::string::npos;
  }
  int depth = 1;
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    if (text[pos] == bracket) {
      ++depth;
    } else if (text[pos] == b_close) {
      --depth;
      if (depth == 0) {
        return pos + 1;
      }
    }
  }
  return std::string::npos;
}
98 
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)99 absl::Status ParseArgsInsideBrackets(const std::string& text,
100                                      size_t open_bracket_pos,
101                                      size_t* close_bracket_pos,
102                                      std::vector<std::string>* args) {
103   *close_bracket_pos =
104       FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
105   if (*close_bracket_pos == -1) {
106     return absl::NotFoundError("Not found enclosing bracket");
107   }
108   std::string str_args = text.substr(open_bracket_pos + 1,
109                                      *close_bracket_pos - open_bracket_pos - 2);
110   std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
111   args->reserve(words.size());
112   for (const auto& word : words) {
113     absl::string_view arg = absl::StripAsciiWhitespace(word);
114     if (!arg.empty()) {
115       args->push_back(std::string(arg));
116     }
117   }
118   return absl::OkStatus();
119 }
120 
AppendArgument(const std::string & arg,std::string * args)121 void AppendArgument(const std::string& arg, std::string* args) {
122   if (!args->empty()) {
123     absl::StrAppend(args, ",\n  ");
124   }
125   absl::StrAppend(args, arg);
126 }
127 
GetImageModifier(AccessType access)128 std::string GetImageModifier(AccessType access) {
129   switch (access) {
130     case AccessType::READ:
131       return "__read_only";
132     case AccessType::WRITE:
133       return "__write_only";
134     case AccessType::READ_WRITE:
135       return "__read_write";
136   }
137 }
138 
GetDefaultSamplers(const GpuInfo & gpu_info)139 std::string GetDefaultSamplers(const GpuInfo& gpu_info) {
140   std::string result;
141   result +=
142       "__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | "
143       "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
144   if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
145     // Unfortunately, CLK_ADDRESS_CLAMP is very slow on Adreno3xx and
146     // we can observe huge register overhead when compared to other modes.
147 
148     // While using CLK_ADDRESS_NONE with out-of-range image coordinates is
149     // undefined in the OpenCL specification, we have observed that
150     // CLK_ADDRESS_NONE works like CLK_ADDRESS_CLAMP for out-of-range image
151     // coordinates for RGBA F16/F32 textures on Adreno3xx devices. Using
152     // CLK_ADDRESS_NONE is significantly faster than CLK_ADDRESS_CLAMP on Adreno
153     // 3xx.
154     result +=
155         "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
156         "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
157   } else {
158     result +=
159         "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
160         "CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n";
161   }
162 
163   return result;
164 }
165 
CreateCLObject(GPUObjectDescriptor * desc,CLContext * context,GPUObjectPtr * result)166 absl::Status CreateCLObject(GPUObjectDescriptor* desc, CLContext* context,
167                             GPUObjectPtr* result) {
168   const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc);
169   if (buffer_desc) {
170     Buffer gpu_buffer;
171     RETURN_IF_ERROR(
172         gpu_buffer.CreateFromBufferDescriptor(*buffer_desc, context));
173     *result = absl::make_unique<Buffer>(std::move(gpu_buffer));
174     return absl::OkStatus();
175   }
176 
177   const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
178   if (texture_desc) {
179     Texture2D gpu_texture;
180     RETURN_IF_ERROR(
181         gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, context));
182     *result = absl::make_unique<Texture2D>(std::move(gpu_texture));
183     return absl::OkStatus();
184   }
185 
186   const auto* linear_desc = dynamic_cast<const TensorLinearDescriptor*>(desc);
187   if (linear_desc) {
188     LinearStorage gpu_storage;
189     RETURN_IF_ERROR(
190         gpu_storage.CreateFromTensorLinearDescriptor(*linear_desc, context));
191     *result = absl::make_unique<LinearStorage>(std::move(gpu_storage));
192     return absl::OkStatus();
193   }
194 
195   const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
196   if (tensor_desc) {
197     Tensor gpu_tensor;
198     RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, context));
199     *result = absl::make_unique<Tensor>(std::move(gpu_tensor));
200     return absl::OkStatus();
201   }
202 
203   return absl::InvalidArgumentError("Unknown GPU descriptor.");
204 }
205 
206 }  // namespace
207 
// Out-of-line definition of the static constexpr data member. Before C++17
// such members are not implicitly inline, so this definition is required
// whenever kArgsPrefix is ODR-used (e.g. passed to strlen() or
// std::string::find() below).
constexpr char CLArguments::kArgsPrefix[];
210 
// Prepares kernel source `*code` and the argument storage that backs it.
//
// The steps run in this order, and later steps depend on earlier ones:
//  1. Allocate an OpenCL object for every object descriptor in `args`.
//  2. Register the resources of all objects and object refs as arguments.
//  3. Expand selector calls in `*code` (see ResolveSelectorsPass), splicing
//     in code from `linkables` for linked tensors.
//  4. Move the object refs out of `args` into this instance.
//  5. Let `args` mark which arguments are active based on `*code` (see
//     Arguments::GetActiveArguments; TODO confirm exact semantics), then
//     pack the scalar values. On PowerVR, half scalars are packed as floats.
//  6. Bind the resources of the allocated objects.
//  7. Rewrite the remaining argument references, fill the absl::Substitute
//     placeholder ($0) with the kernel parameter list, and prepend the
//     default samplers when the GPU supports images.
absl::Status CLArguments::Init(
    const GpuInfo& gpu_info,
    const std::map<std::string, std::string>& linkables, CLContext* context,
    Arguments* args, std::string* code) {
  RETURN_IF_ERROR(AllocateObjects(*args, context));
  RETURN_IF_ERROR(AddObjectArgs(args));
  RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, *args, linkables, code));
  // ResolveSelectorsPass reads args->object_refs_, so move them only after it.
  object_refs_ = std::move(args->object_refs_);
  args->GetActiveArguments(kArgsPrefix, *code);
  const bool use_f32_for_halfs = gpu_info.IsPowerVR();
  CopyArguments(*args, use_f32_for_halfs);
  RETURN_IF_ERROR(SetObjectsResources(*args));
  RenameArgumentsInCode(code);
  ResolveArgsPass(code);
  *code = absl::Substitute(*code, GetListOfArgs());
  if (gpu_info.SupportsImages()) {
    *code = GetDefaultSamplers(gpu_info) + *code;
  }
  return absl::OkStatus();
}
231 
Init(const GpuInfo & gpu_info,Arguments * args,CLContext * context)232 absl::Status CLArguments::Init(const GpuInfo& gpu_info, Arguments* args,
233                                CLContext* context) {
234   RETURN_IF_ERROR(AllocateObjects(*args, context));
235   RETURN_IF_ERROR(AddObjectArgs(args));
236   object_refs_ = std::move(args->object_refs_);
237   const bool use_f32_for_halfs = gpu_info.IsPowerVR();
238   CopyArguments(*args, use_f32_for_halfs);
239   RETURN_IF_ERROR(SetObjectsResources(*args));
240   return absl::OkStatus();
241 }
242 
AllocateObjects(const Arguments & args,CLContext * context)243 absl::Status CLArguments::AllocateObjects(const Arguments& args,
244                                           CLContext* context) {
245   objects_.resize(args.objects_.size());
246   int i = 0;
247   for (auto& t : args.objects_) {
248     RETURN_IF_ERROR(CreateCLObject(t.second.get(), context, &objects_[i]));
249     i++;
250   }
251   return absl::OkStatus();
252 }
253 
AddObjectArgs(Arguments * args)254 absl::Status CLArguments::AddObjectArgs(Arguments* args) {
255   for (auto& t : args->objects_) {
256     AddGPUResources(t.first, t.second->GetGPUResources(), args);
257   }
258   for (auto& t : args->object_refs_) {
259     AddGPUResources(t.first, t.second->GetGPUResources(), args);
260   }
261   return absl::OkStatus();
262 }
263 
SetObjectsResources(const Arguments & args)264 absl::Status CLArguments::SetObjectsResources(const Arguments& args) {
265   int i = 0;
266   for (const auto& t : args.objects_) {
267     GPUResourcesWithValue resources;
268     RETURN_IF_ERROR(objects_[i]->GetGPUResources(t.second.get(), &resources));
269     RETURN_IF_ERROR(SetGPUResources(t.first, resources));
270     i++;
271   }
272   return absl::OkStatus();
273 }
274 
ResolveSelectorsPass(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,std::string * code)275 absl::Status CLArguments::ResolveSelectorsPass(
276     const GpuInfo& gpu_info, const Arguments& args,
277     const std::map<std::string, std::string>& linkables, std::string* code) {
278   std::string result;
279   size_t position = 0;
280   size_t next_position = code->find(kArgsPrefix);
281   while (next_position != std::string::npos) {
282     size_t arg_pos = next_position;
283     next_position += strlen(kArgsPrefix);
284     std::string object_name = GetNextWord(*code, next_position);
285     char next = (*code)[next_position + object_name.size()];
286     if (next == '.') {
287       next_position += object_name.size() + 1;
288       std::string selector_name = GetNextWord(*code, next_position);
289       next_position += selector_name.size();
290       next = (*code)[next_position];
291       std::vector<std::string> template_args;
292       if (next == '<') {
293         size_t close_bracket_pos;
294         RETURN_IF_ERROR(ParseArgsInsideBrackets(
295             *code, next_position, &close_bracket_pos, &template_args));
296         next_position = close_bracket_pos;
297         next = (*code)[next_position];
298       }
299       if (next != '(') {
300         return absl::NotFoundError(absl::StrCat(
301             "Expected ( after ", object_name, ".", selector_name, " call"));
302       }
303       std::vector<std::string> function_args;
304       size_t close_bracket_pos;
305       RETURN_IF_ERROR(ParseArgsInsideBrackets(
306           *code, next_position, &close_bracket_pos, &function_args));
307       for (auto& arg : function_args) {
308         RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
309       }
310       std::string patch;
311       RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
312                                       selector_name, function_args,
313                                       template_args, &patch));
314       code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
315       position = arg_pos + patch.size();
316     } else {
317       position = arg_pos + strlen(kArgsPrefix);
318     }
319     next_position = code->find(kArgsPrefix, position);
320   }
321   return absl::OkStatus();
322 }
323 
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code)324 void CLArguments::ResolveObjectNames(
325     const std::string& object_name,
326     const std::vector<std::string>& member_names, std::string* code) {
327   for (const auto& member_name : member_names) {
328     const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
329     ReplaceAllWords(member_name, new_name, code);
330   }
331 }
332 
// Generates the OpenCL code for one selector call
// `<object_name>.<selector>(function_args)` and stores it in `*result`.
//
// The object is looked up among object references first, then among owned
// objects.
//
// Linked writes: for a TensorDescriptor object that has an entry in
// `linkables`, the "Write" and "Linking" selectors first overwrite `*result`
// with the linked code. Its placeholders (in_out_value, X_COORD, Y_COORD,
// S_COORD) are replaced by the values extracted from `function_args`.
// "Linking" stops there; "Write" then appends the regular write code.
//
// Otherwise the selector's code is appended to `*result`. Resource member
// names in the selector output are qualified via ResolveObjectNames.
//
// Returns NotFoundError for an unknown object and FailedPreconditionError when
// a linked tensor lacks write access. Other errors are propagated.
absl::Status CLArguments::ResolveSelector(
    const GpuInfo& gpu_info, const Arguments& args,
    const std::map<std::string, std::string>& linkables,
    const std::string& object_name, const std::string& selector,
    const std::vector<std::string>& function_args,
    const std::vector<std::string>& template_args, std::string* result) {
  const GPUObjectDescriptor* desc_ptr;
  auto it_ref = args.object_refs_.find(object_name);
  auto it_obj = args.objects_.find(object_name);
  if (it_ref != args.object_refs_.end()) {
    desc_ptr = it_ref->second.get();
  } else if (it_obj != args.objects_.end()) {
    desc_ptr = it_obj->second.get();
  } else {
    return absl::NotFoundError(
        absl::StrCat("No object with name - ", object_name));
  }
  auto names = desc_ptr->GetGPUResources().GetNames();
  const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
  if (tensor_desc && (selector == "Write" || selector == "Linking")) {
    auto it = linkables.find(object_name);
    if (it != linkables.end()) {
      if (desc_ptr->GetAccess() != AccessType::WRITE &&
          desc_ptr->GetAccess() != AccessType::READ_WRITE) {
        return absl::FailedPreconditionError(absl::StrCat(
            "Object with name - ", object_name, " should have Write access."));
      }
      std::string value_name, x_coord, y_coord, s_coord;
      RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
          function_args, &value_name, &x_coord, &y_coord, &s_coord));
      // x_coord can reference members of this object (e.g. its batch size),
      // so qualify them before substituting it into the linked code.
      ResolveObjectNames(object_name, names, &x_coord);
      *result = it->second;
      ReplaceAllWords("in_out_value", value_name, result);
      ReplaceAllWords("X_COORD", x_coord, result);
      ReplaceAllWords("Y_COORD", y_coord, result);
      ReplaceAllWords("S_COORD", s_coord, result);
      // The linked code may contain selector calls of its own; expand them
      // without any further linking.
      RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
      if (selector == "Linking") {
        return absl::OkStatus();
      }
    }
  }
  std::string patch;
  RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
                                            template_args, &patch));
  ResolveObjectNames(object_name, names, &patch);
  *result += patch;
  return absl::OkStatus();
}
383 
ResolveArgsPass(std::string * code)384 void CLArguments::ResolveArgsPass(std::string* code) {
385   size_t position = 0;
386   size_t next_position = code->find(kArgsPrefix);
387   while (next_position != std::string::npos) {
388     size_t arg_pos = next_position;
389     next_position += strlen(kArgsPrefix);
390     std::string object_name = GetNextWord(*code, next_position);
391     std::string new_name = object_name;
392     code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
393     position = arg_pos + new_name.size();
394     next_position = code->find(kArgsPrefix, position);
395   }
396 }
397 
CopyScalarValues(Arguments * args) const398 void CLArguments::CopyScalarValues(Arguments* args) const {
399   for (const auto& fvalue : float_values_) {
400     args->float_values_[fvalue.first].value = fvalue.second.value;
401   }
402   for (const auto& ivalue : int_values_) {
403     args->int_values_[ivalue.first].value = ivalue.second.value;
404   }
405   for (const auto& hfvalue : half_values_) {
406     args->half_values_[hfvalue.first].value = hfvalue.second.value;
407   }
408 }
409 
CopyArguments(const Arguments & args,bool use_f32_for_halfs)410 void CLArguments::CopyArguments(const Arguments& args, bool use_f32_for_halfs) {
411   for (const auto& fvalue : args.float_values_) {
412     auto& new_val = float_values_[fvalue.first];
413     new_val.value = fvalue.second.value;
414     new_val.active = fvalue.second.active;
415     if (fvalue.second.active) {
416       new_val.offset = shared_float4s_data_.size();
417       shared_float4s_data_.push_back(new_val.value);
418     }
419   }
420   for (const auto& ivalue : args.int_values_) {
421     auto& new_val = int_values_[ivalue.first];
422     new_val.value = ivalue.second.value;
423     new_val.active = ivalue.second.active;
424     if (ivalue.second.active) {
425       new_val.offset = shared_int4s_data_.size();
426       shared_int4s_data_.push_back(new_val.value);
427     }
428   }
429   for (const auto& hfvalue : args.half_values_) {
430     auto& new_val = half_values_[hfvalue.first];
431     new_val.value = hfvalue.second.value;
432     new_val.active = hfvalue.second.active;
433     if (hfvalue.second.active) {
434       if (use_f32_for_halfs) {
435         new_val.store_as_f32 = true;
436         new_val.offset = shared_float4s_data_.size();
437         shared_float4s_data_.push_back(new_val.value);
438       } else {
439         new_val.store_as_f32 = false;
440         new_val.offset = shared_half4s_data_.size();
441         shared_half4s_data_.push_back(new_val.value);
442       }
443     }
444   }
445   int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
446   shared_int4s_data_.resize(shared_int4s_aligned_size);
447   int shared_float4s_aligned_size = AlignByN(shared_float4s_data_.size(), 4);
448   shared_float4s_data_.resize(shared_float4s_aligned_size);
449   int shared_half4s_aligned_size = AlignByN(shared_half4s_data_.size(), 4);
450   shared_half4s_data_.resize(shared_half4s_aligned_size);
451 }
452 
RenameArgumentsInCode(std::string * code)453 void CLArguments::RenameArgumentsInCode(std::string* code) {
454   const std::string postfixes[4] = {"x", "y", "z", "w"};
455   for (const auto& fvalue : float_values_) {
456     if (fvalue.second.active) {
457       std::string index = std::to_string(fvalue.second.offset / 4);
458       std::string new_name =
459           "shared_float4_" + index + "." + postfixes[fvalue.second.offset % 4];
460       ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
461     }
462   }
463   for (const auto& ivalue : int_values_) {
464     if (ivalue.second.active) {
465       std::string index = std::to_string(ivalue.second.offset / 4);
466       std::string new_name =
467           "shared_int4_" + index + "." + postfixes[ivalue.second.offset % 4];
468       ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
469     }
470   }
471   for (const auto& hfvalue : half_values_) {
472     if (hfvalue.second.active) {
473       std::string index = std::to_string(hfvalue.second.offset / 4);
474       std::string new_name;
475       if (hfvalue.second.store_as_f32) {
476         new_name = "(half)(shared_float4_" + index + "." +
477                    postfixes[hfvalue.second.offset % 4] + ")";
478       } else {
479         new_name = "shared_half4_" + index + "." +
480                    postfixes[hfvalue.second.offset % 4];
481       }
482       ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
483     }
484   }
485 }
486 
AddBuffer(const std::string & name,const GPUBufferDescriptor & desc)487 void CLArguments::AddBuffer(const std::string& name,
488                             const GPUBufferDescriptor& desc) {
489   buffers_[name].desc = desc;
490 }
AddImage2D(const std::string & name,const GPUImage2DDescriptor & desc)491 void CLArguments::AddImage2D(const std::string& name,
492                              const GPUImage2DDescriptor& desc) {
493   images2d_[name].desc = desc;
494 }
495 
AddImage2DArray(const std::string & name,const GPUImage2DArrayDescriptor & desc)496 void CLArguments::AddImage2DArray(const std::string& name,
497                                   const GPUImage2DArrayDescriptor& desc) {
498   image2d_arrays_[name].desc = desc;
499 }
500 
AddImage3D(const std::string & name,const GPUImage3DDescriptor & desc)501 void CLArguments::AddImage3D(const std::string& name,
502                              const GPUImage3DDescriptor& desc) {
503   images3d_[name].desc = desc;
504 }
505 
AddImageBuffer(const std::string & name,const GPUImageBufferDescriptor & desc)506 void CLArguments::AddImageBuffer(const std::string& name,
507                                  const GPUImageBufferDescriptor& desc) {
508   image_buffers_[name].desc = desc;
509 }
510 
AddCustomMemory(const std::string & name,const GPUCustomMemoryDescriptor & desc)511 void CLArguments::AddCustomMemory(const std::string& name,
512                                   const GPUCustomMemoryDescriptor& desc) {
513   custom_memories_[name].desc = desc;
514 }
515 
AddGPUResources(const std::string & name,const GPUResources & resources,Arguments * args)516 void CLArguments::AddGPUResources(const std::string& name,
517                                   const GPUResources& resources,
518                                   Arguments* args) {
519   for (const auto& r : resources.ints) {
520     args->AddInt(absl::StrCat(name, "_", r));
521   }
522   for (const auto& r : resources.floats) {
523     args->AddFloat(absl::StrCat(name, "_", r));
524   }
525   for (const auto& r : resources.buffers) {
526     AddBuffer(absl::StrCat(name, "_", r.first), r.second);
527   }
528   for (const auto& r : resources.images2d) {
529     AddImage2D(absl::StrCat(name, "_", r.first), r.second);
530   }
531   for (const auto& r : resources.image2d_arrays) {
532     AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
533   }
534   for (const auto& r : resources.images3d) {
535     AddImage3D(absl::StrCat(name, "_", r.first), r.second);
536   }
537   for (const auto& r : resources.image_buffers) {
538     AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
539   }
540   for (const auto& r : resources.custom_memories) {
541     AddCustomMemory(absl::StrCat(name, "_", r.first), r.second);
542   }
543 }
544 
SetInt(const std::string & name,int value)545 absl::Status CLArguments::SetInt(const std::string& name, int value) {
546   auto it = int_values_.find(name);
547   if (it == int_values_.end()) {
548     return absl::NotFoundError(
549         absl::StrCat("No int argument with name - ", name));
550   }
551   it->second.value = value;
552   if (it->second.active) {
553     shared_int4s_data_[it->second.offset] = value;
554   }
555   return absl::OkStatus();
556 }
SetFloat(const std::string & name,float value)557 absl::Status CLArguments::SetFloat(const std::string& name, float value) {
558   auto it = float_values_.find(name);
559   if (it == float_values_.end()) {
560     return absl::NotFoundError(
561         absl::StrCat("No float argument with name - ", name));
562   }
563   it->second.value = value;
564   if (it->second.active) {
565     shared_float4s_data_[it->second.offset] = value;
566   }
567   return absl::OkStatus();
568 }
569 
SetHalf(const std::string & name,half value)570 absl::Status CLArguments::SetHalf(const std::string& name, half value) {
571   auto it = half_values_.find(name);
572   if (it == half_values_.end()) {
573     return absl::NotFoundError(
574         absl::StrCat("No half argument with name - ", name));
575   }
576   it->second.value = value;
577   if (it->second.active) {
578     if (it->second.store_as_f32) {
579       shared_float4s_data_[it->second.offset] = value;
580     } else {
581       shared_half4s_data_[it->second.offset] = value;
582     }
583   }
584   return absl::OkStatus();
585 }
586 
SetImage2D(const std::string & name,cl_mem memory)587 absl::Status CLArguments::SetImage2D(const std::string& name, cl_mem memory) {
588   auto it = images2d_.find(name);
589   if (it == images2d_.end()) {
590     return absl::NotFoundError(
591         absl::StrCat("No image2D argument with name - ", name));
592   }
593   it->second.memory = memory;
594   return absl::OkStatus();
595 }
596 
SetBuffer(const std::string & name,cl_mem memory)597 absl::Status CLArguments::SetBuffer(const std::string& name, cl_mem memory) {
598   auto it = buffers_.find(name);
599   if (it == buffers_.end()) {
600     return absl::NotFoundError(
601         absl::StrCat("No buffer argument with name - ", name));
602   }
603   it->second.memory = memory;
604   return absl::OkStatus();
605 }
606 
SetImage2DArray(const std::string & name,cl_mem memory)607 absl::Status CLArguments::SetImage2DArray(const std::string& name,
608                                           cl_mem memory) {
609   auto it = image2d_arrays_.find(name);
610   if (it == image2d_arrays_.end()) {
611     return absl::NotFoundError(
612         absl::StrCat("No image2D array argument with name - ", name));
613   }
614   it->second.memory = memory;
615   return absl::OkStatus();
616 }
617 
SetImage3D(const std::string & name,cl_mem memory)618 absl::Status CLArguments::SetImage3D(const std::string& name, cl_mem memory) {
619   auto it = images3d_.find(name);
620   if (it == images3d_.end()) {
621     return absl::NotFoundError(
622         absl::StrCat("No image3D argument with name - ", name));
623   }
624   it->second.memory = memory;
625   return absl::OkStatus();
626 }
627 
SetImageBuffer(const std::string & name,cl_mem memory)628 absl::Status CLArguments::SetImageBuffer(const std::string& name,
629                                          cl_mem memory) {
630   auto it = image_buffers_.find(name);
631   if (it == image_buffers_.end()) {
632     return absl::NotFoundError(
633         absl::StrCat("No image buffer argument with name - ", name));
634   }
635   it->second.memory = memory;
636   return absl::OkStatus();
637 }
638 
SetCustomMemory(const std::string & name,cl_mem memory)639 absl::Status CLArguments::SetCustomMemory(const std::string& name,
640                                           cl_mem memory) {
641   auto it = custom_memories_.find(name);
642   if (it == custom_memories_.end()) {
643     return absl::NotFoundError(
644         absl::StrCat("No custom memory argument with name - ", name));
645   }
646   it->second.memory = memory;
647   return absl::OkStatus();
648 }
649 
SetObjectRef(const std::string & name,const GPUObject * object)650 absl::Status CLArguments::SetObjectRef(const std::string& name,
651                                        const GPUObject* object) {
652   auto it = object_refs_.find(name);
653   if (it == object_refs_.end()) {
654     return absl::NotFoundError(
655         absl::StrCat("No object ref with name - ", name));
656   }
657   GPUResourcesWithValue resources;
658   RETURN_IF_ERROR(object->GetGPUResources(it->second.get(), &resources));
659   return SetGPUResources(name, resources);
660 }
661 
SetGPUResources(const std::string & name,const GPUResourcesWithValue & resources)662 absl::Status CLArguments::SetGPUResources(
663     const std::string& name, const GPUResourcesWithValue& resources) {
664   for (const auto& r : resources.ints) {
665     RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
666   }
667   for (const auto& r : resources.floats) {
668     RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
669   }
670   for (const auto& r : resources.buffers) {
671     RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
672   }
673   for (const auto& r : resources.images2d) {
674     RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
675   }
676   for (const auto& r : resources.image2d_arrays) {
677     RETURN_IF_ERROR(
678         SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
679   }
680   for (const auto& r : resources.images3d) {
681     RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
682   }
683   for (const auto& r : resources.image_buffers) {
684     RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
685   }
686   for (const auto& r : resources.custom_memories) {
687     RETURN_IF_ERROR(
688         SetCustomMemory(absl::StrCat(name, "_", r.first), r.second));
689   }
690   return absl::OkStatus();
691 }
692 
GetListOfArgs()693 std::string CLArguments::GetListOfArgs() {
694   std::string result;
695   for (auto& t : buffers_) {
696     const std::string type_name =
697         t.second.desc.data_type == DataType::FLOAT32 ? "float" : "half";
698     std::string attributes;
699     for (const auto& attr : t.second.desc.attributes) {
700       attributes += absl::StrCat("  __attribute__((", attr, "))");
701     }
702     AppendArgument(
703         absl::StrCat(
704             MemoryTypeToCLType(t.second.desc.memory_type), " ",
705             ToCLDataType(t.second.desc.data_type, t.second.desc.element_size),
706             "* ", t.first, attributes),
707         &result);
708   }
709   for (auto& t : image_buffers_) {
710     AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
711                                 " image1d_buffer_t ", t.first),
712                    &result);
713   }
714   for (auto& t : images2d_) {
715     AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
716                                 " image2d_t ", t.first),
717                    &result);
718   }
719   for (auto& t : image2d_arrays_) {
720     AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
721                                 " image2d_array_t ", t.first),
722                    &result);
723   }
724   for (auto& t : images3d_) {
725     AppendArgument(absl::StrCat(GetImageModifier(t.second.desc.access_type),
726                                 " image3d_t ", t.first),
727                    &result);
728   }
729   for (auto& t : custom_memories_) {
730     AppendArgument(absl::StrCat(t.second.desc.type_name, " ", t.first),
731                    &result);
732   }
733   for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
734     AppendArgument(absl::StrCat("int4 shared_int4_", i), &result);
735   }
736   for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
737     AppendArgument(absl::StrCat("float4 shared_float4_", i), &result);
738   }
739   for (int i = 0; i < shared_half4s_data_.size() / 4; ++i) {
740     AppendArgument(absl::StrCat("half4 shared_half4_", i), &result);
741   }
742   return result;
743 }
744 
Bind(cl_kernel kernel,int offset)745 absl::Status CLArguments::Bind(cl_kernel kernel, int offset) {
746   for (auto& t : buffers_) {
747     const int error_code =
748         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
749     if (error_code != CL_SUCCESS) {
750       return absl::UnknownError(absl::StrCat(
751           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
752           "(at index - ", offset, ")"));
753     }
754     offset++;
755   }
756   for (auto& t : image_buffers_) {
757     const int error_code =
758         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
759     if (error_code != CL_SUCCESS) {
760       return absl::UnknownError(absl::StrCat(
761           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
762           "(at index - ", offset, ")"));
763     }
764     offset++;
765   }
766   for (auto& t : images2d_) {
767     const int error_code =
768         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
769     if (error_code != CL_SUCCESS) {
770       return absl::UnknownError(absl::StrCat(
771           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
772           "(at index - ", offset, ")"));
773     }
774     offset++;
775   }
776   for (auto& t : image2d_arrays_) {
777     const int error_code =
778         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
779     if (error_code != CL_SUCCESS) {
780       return absl::UnknownError(absl::StrCat(
781           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
782           "(at index - ", offset, ")"));
783     }
784     offset++;
785   }
786   for (auto& t : images3d_) {
787     const int error_code =
788         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
789     if (error_code != CL_SUCCESS) {
790       return absl::UnknownError(absl::StrCat(
791           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
792           "(at index - ", offset, ")"));
793     }
794     offset++;
795   }
796   for (auto& t : custom_memories_) {
797     const int error_code =
798         clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
799     if (error_code != CL_SUCCESS) {
800       return absl::UnknownError(absl::StrCat(
801           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
802           "(at index - ", offset, ")"));
803     }
804     offset++;
805   }
806   for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) {
807     const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
808                                           &shared_int4s_data_[i * 4]);
809     if (error_code != CL_SUCCESS) {
810       return absl::UnknownError(absl::StrCat(
811           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
812           "(at index - ", offset, ")"));
813     }
814     offset++;
815   }
816   for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) {
817     const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4,
818                                           &shared_float4s_data_[i * 4]);
819     if (error_code != CL_SUCCESS) {
820       return absl::UnknownError(absl::StrCat(
821           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
822           "(at index - ", offset, ")"));
823     }
824     offset++;
825   }
826   for (int i = 0; i < shared_half4s_data_.size() / 4; ++i) {
827     const int error_code = clSetKernelArg(kernel, offset, sizeof(int16_t) * 4,
828                                           &shared_half4s_data_[i * 4]);
829     if (error_code != CL_SUCCESS) {
830       return absl::UnknownError(absl::StrCat(
831           "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
832           "(at index - ", offset, ")"));
833     }
834     offset++;
835   }
836   return absl::OkStatus();
837 }
838 
839 }  // namespace cl
840 }  // namespace gpu
841 }  // namespace tflite
842