• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"

#include <cstring>
#include <string>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/metal/linear_storage.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
26 
27 namespace tflite {
28 namespace gpu {
29 namespace metal {
30 namespace {
IsWordSymbol(char symbol)31 bool IsWordSymbol(char symbol) {
32   return absl::ascii_isalnum(symbol) || symbol == '_';
33 }
34 
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)35 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
36                      std::string* str) {
37   size_t position = str->find(old_word);
38   while (position != std::string::npos) {
39     char prev = position == 0 ? '.' : (*str)[position - 1];
40     char next = position + old_word.size() < str->size()
41                     ? (*str)[position + old_word.size()]
42                     : '.';
43     if (IsWordSymbol(prev) || IsWordSymbol(next)) {
44       position = str->find(old_word, position + 1);
45       continue;
46     }
47     str->replace(position, old_word.size(), new_word);
48     position = str->find(old_word, position + new_word.size());
49   }
50 }
51 
GetNextWord(const std::string & code,size_t first_position)52 std::string GetNextWord(const std::string& code, size_t first_position) {
53   size_t pos = first_position;
54   char t = code[pos];
55   while (IsWordSymbol(t)) {
56     pos++;
57     t = code[pos];
58   }
59   return code.substr(first_position, pos - first_position);
60 }
61 
// Finds the bracket that closes an already-consumed opening `bracket`.
// Scanning starts at `first_pos`, which should be the character right after
// the opening bracket. Nested brackets of the same kind are honored.
// Supported opening brackets: '(', '{', '[', '<'.
//
// Returns the position one past the matching closing bracket. Returns
// std::string::npos (the same value as size_t(-1)) when `bracket` is not a
// supported opening bracket or when no match exists before the end of `text`.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  // A switch avoids building a lookup map on every call.
  char b_close;
  switch (bracket) {
    case '(':
      b_close = ')';
      break;
    case '{':
      b_close = '}';
      break;
    case '[':
      b_close = ']';
      break;
    case '<':
      b_close = '>';
      break;
    default:
      return std::string::npos;
  }
  const char b_open = bracket;
  // The opening bracket before first_pos has already been consumed.
  int depth = 1;
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    if (text[pos] == b_open) {
      ++depth;
    } else if (text[pos] == b_close) {
      if (--depth == 0) {
        return pos + 1;
      }
    }
  }
  return std::string::npos;
}
93 
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)94 absl::Status ParseArgsInsideBrackets(const std::string& text,
95                                      size_t open_bracket_pos,
96                                      size_t* close_bracket_pos,
97                                      std::vector<std::string>* args) {
98   *close_bracket_pos =
99       FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
100   if (*close_bracket_pos == -1) {
101     return absl::NotFoundError("Not found enclosing bracket");
102   }
103   std::string str_args = text.substr(open_bracket_pos + 1,
104                                      *close_bracket_pos - open_bracket_pos - 2);
105   std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
106   args->reserve(words.size());
107   for (const auto& word : words) {
108     absl::string_view arg = absl::StripAsciiWhitespace(word);
109     if (!arg.empty()) {
110       args->push_back(std::string(arg));
111     }
112   }
113   return absl::OkStatus();
114 }
115 
AppendArgument(const std::string & arg,std::string * args)116 void AppendArgument(const std::string& arg, std::string* args) {
117   if (!args->empty()) {
118     absl::StrAppend(args, ",\n");
119   }
120   absl::StrAppend(args, arg);
121 }
122 
CreateMetalObject(id<MTLDevice> device,GPUObjectDescriptor * desc,GPUObjectPtr * result)123 absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
124                             GPUObjectPtr* result) {
125   const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc);
126   if (buffer_desc) {
127     Buffer gpu_buffer;
128     RETURN_IF_ERROR(
129         gpu_buffer.CreateFromBufferDescriptor(*buffer_desc, device));
130     *result = absl::make_unique<Buffer>(std::move(gpu_buffer));
131     return absl::OkStatus();
132   }
133 
134   const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
135   if (texture_desc) {
136     Texture2D gpu_texture;
137     RETURN_IF_ERROR(
138         gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, device));
139     *result = absl::make_unique<Texture2D>(std::move(gpu_texture));
140     return absl::OkStatus();
141   }
142 
143   const auto* linear_desc = dynamic_cast<const TensorLinearDescriptor*>(desc);
144   if (linear_desc) {
145     LinearStorage gpu_storage;
146     RETURN_IF_ERROR(
147         gpu_storage.CreateFromTensorLinearDescriptor(*linear_desc, device));
148     *result = absl::make_unique<LinearStorage>(std::move(gpu_storage));
149     return absl::OkStatus();
150   }
151 
152   const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
153   if (tensor_desc) {
154     MetalSpatialTensor gpu_tensor;
155     RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
156     *result = absl::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
157     return absl::OkStatus();
158   }
159 
160   return absl::InvalidArgumentError("Unknown GPU descriptor.");
161 }
162 
AccessToMetalTextureAccess(AccessType access_type)163 std::string AccessToMetalTextureAccess(AccessType access_type) {
164   if (access_type == AccessType::READ) {
165     return "access::read";
166   } else if (access_type == AccessType::READ_WRITE) {
167     return "access::read_write";
168   } else if (access_type == AccessType::WRITE) {
169     return "access::write";
170   } else {
171     return "access::unknown";
172   }
173 }
174 }  // namespace
175 
// Out-of-line definition of the static constexpr member kArgsPrefix.
// Before C++17, a static constexpr data member that is ODR-used needs a
// namespace-scope definition. It is ODR-used here: it decays to a pointer in
// expressions such as `kArgsPrefix + name` and `strlen(kArgsPrefix)`.
constexpr char MetalArguments::kArgsPrefix[];
178 
// Turns the kernel source in `*code`, which refers to arguments through the
// kArgsPrefix ("args.") prefix, into compilable Metal source. It also records
// everything Encode() needs to bind the arguments later.
//
// The passes run in a fixed order because later ones consume state produced
// by earlier ones:
//   1. AllocateObjects creates a Metal object for every entry of
//      args->objects_.
//   2. AddObjectArgs registers the GPU resources of objects and object refs.
//   3. ResolveSelectorsPass expands `args.<obj>.<Selector>(...)` calls,
//      splicing in `linkables` code where applicable.
//   4. object_refs_ takes ownership of args->object_refs_.
//   5. Active scalar arguments are packed into a `uniforms_buffer` struct.
//   6. SetObjectsResources binds the resource handles of owned objects.
//   7. ResolveArgsPass strips the remaining "args." prefixes.
// Finally the Metal header and uniforms struct are prepended. The generated
// parameter list is substituted for `$0` in the code via absl::Substitute.
//
// Returns the first error produced by any of the failing passes.
absl::Status MetalArguments::Init(
    const std::map<std::string, std::string>& linkables, MetalDevice* device,
    Arguments* args, std::string* code) {
  RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
  RETURN_IF_ERROR(AddObjectArgs(args));
  RETURN_IF_ERROR(
      ResolveSelectorsPass(device->GetInfo(), *args, linkables, code));
  // ResolveSelectorsPass looks objects up in args->object_refs_, so the move
  // must come after it. args->object_refs_ is moved-from afterwards.
  object_refs_ = std::move(args->object_refs_);
  // NOTE(review): presumably marks the scalar args referenced in `code` as
  // active; only active scalars are laid out in the uniforms struct. Confirm.
  args->GetActiveArguments(kArgsPrefix, *code);
  std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
  // Must follow the packing step: SetInt/SetFloat only write const_data_ for
  // arguments that were assigned a byte offset there.
  RETURN_IF_ERROR(SetObjectsResources(*args));
  ResolveArgsPass(code);
  std::string header = R"(
#include <metal_stdlib>
using namespace metal;

)";
  header += struct_desc + "\n";
  *code = header + *code;
  std::string arguments = GetListOfArgs(/*buffer_offset*/ 0);
  // Built-in thread-position parameters are declared only when the code
  // mentions the corresponding names.
  const bool use_global_id = code->find("GLOBAL_ID_") != std::string::npos;
  const bool use_local_id = code->find("LOCAL_ID_") != std::string::npos;
  const bool use_group_id = code->find("GROUP_ID_") != std::string::npos;
  const bool use_group_size = code->find("GROUP_SIZE_") != std::string::npos;
  const bool use_simd_id =
      code->find("SUB_GROUP_LOCAL_ID") != std::string::npos;
  if (use_global_id) {
    AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
  }
  if (use_local_id) {
    AppendArgument("uint3 reserved_lid[[thread_position_in_threadgroup]]",
                   &arguments);
  }
  if (use_group_id) {
    AppendArgument("uint3 reserved_group_id[[threadgroup_position_in_grid]]",
                   &arguments);
  }
  if (use_group_size) {
    AppendArgument("uint3 reserved_group_size[[threads_per_threadgroup]]",
                   &arguments);
  }
  if (use_simd_id) {
    AppendArgument("uint reserved_simd_id[[thread_index_in_simdgroup]]",
                   &arguments);
  }
  // When no GLOBAL/LOCAL/GROUP id is referenced, a non-empty list gets a
  // trailing separator. NOTE(review): presumably the kernel template then
  // declares further parameters of its own after `$0`. Confirm. use_simd_id
  // deliberately (?) does not take part in this check.
  if (!use_global_id && !use_local_id && !use_group_id && !use_group_size &&
      !arguments.empty()) {
    arguments += ",\n";
  }
  *code = absl::Substitute(*code, arguments);
  return absl::OkStatus();
}
231 
ScalarArgumentsToStructWithScalarFields(Arguments * args,std::string * code)232 std::string MetalArguments::ScalarArgumentsToStructWithScalarFields(
233     Arguments* args, std::string* code) {
234   std::string struct_desc = "struct uniforms_buffer {\n";
235   int pos = 0;
236   for (auto& fvalue : args->float_values_) {
237     auto& new_val = float_values_[fvalue.first];
238     new_val.value = fvalue.second.value;
239     new_val.active = fvalue.second.active;
240     if (fvalue.second.active) {
241       new_val.bytes_offset = pos * 4;
242       pos++;
243       struct_desc += "  float " + fvalue.first + ";\n";
244       ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code);
245     }
246   }
247   for (const auto& hfvalue : args->half_values_) {
248     auto& new_val = float_values_[hfvalue.first];
249     new_val.value = hfvalue.second.value;
250     new_val.active = hfvalue.second.active;
251     if (hfvalue.second.active) {
252       new_val.bytes_offset = pos * 4;
253       pos++;
254       struct_desc += "  float " + hfvalue.first + ";\n";
255       ReplaceAllWords(kArgsPrefix + hfvalue.first,
256                       "static_cast<half>(U." + hfvalue.first + ")", code);
257     }
258   }
259   for (auto& ivalue : args->int_values_) {
260     auto& new_val = int_values_[ivalue.first];
261     new_val.value = ivalue.second.value;
262     new_val.active = ivalue.second.active;
263     if (ivalue.second.active) {
264       new_val.bytes_offset = pos * 4;
265       pos++;
266       struct_desc += "  int " + ivalue.first + ";\n";
267       ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code);
268     }
269   }
270   if (pos != 0) {
271     int aligned_pos = AlignByN(pos, 4);
272     for (int i = pos; i < aligned_pos; i++) {
273       struct_desc += "  int dummy" + std::to_string(i - pos) + ";\n";
274     }
275     struct_desc += "};";
276     const_data_.resize(aligned_pos * 4);
277     for (auto& it : float_values_) {
278       if (it.second.active) {
279         float* ptr =
280             reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
281         *ptr = it.second.value;
282       }
283     }
284     for (auto& it : int_values_) {
285       if (it.second.active) {
286         int32_t* ptr =
287             reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
288         *ptr = it.second.value;
289       }
290     }
291   } else {
292     struct_desc = "";
293   }
294   return struct_desc;
295 }
296 
ScalarArgumentsToStructWithVec4Fields(Arguments * args,std::string * code)297 std::string MetalArguments::ScalarArgumentsToStructWithVec4Fields(
298     Arguments* args, std::string* code) {
299   std::string struct_desc = "struct uniforms_buffer {\n";
300   int pos = 0;
301   std::string channels[4] = {".x", ".y", ".z", ".w"};
302   for (auto& fvalue : args->float_values_) {
303     auto& new_val = float_values_[fvalue.first];
304     new_val.value = fvalue.second.value;
305     new_val.active = fvalue.second.active;
306     if (fvalue.second.active) {
307       new_val.bytes_offset = pos * 4;
308       if (pos % 4 == 0) {
309         struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
310       }
311       std::string new_name =
312           "U.cmp_float4_" + std::to_string(pos / 4) + channels[pos % 4];
313       ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
314       pos++;
315     }
316   }
317   for (const auto& hfvalue : args->half_values_) {
318     auto& new_val = float_values_[hfvalue.first];
319     new_val.value = hfvalue.second.value;
320     new_val.active = hfvalue.second.active;
321     if (hfvalue.second.active) {
322       new_val.bytes_offset = pos * 4;
323       if (pos % 4 == 0) {
324         struct_desc += "  float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
325       }
326       std::string new_name = "static_cast<half>(U.cmp_float4_" +
327                              std::to_string(pos / 4) + channels[pos % 4] + ")";
328       ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
329       pos++;
330     }
331   }
332   pos = AlignByN(pos, 4);
333   for (auto& ivalue : args->int_values_) {
334     auto& new_val = int_values_[ivalue.first];
335     new_val.value = ivalue.second.value;
336     new_val.active = ivalue.second.active;
337     if (ivalue.second.active) {
338       new_val.bytes_offset = pos * 4;
339       if (pos % 4 == 0) {
340         struct_desc += "  int4 cmp_int4_" + std::to_string(pos / 4) + ";\n";
341       }
342       std::string new_name =
343           "U.cmp_int4_" + std::to_string(pos / 4) + channels[pos % 4];
344       ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
345       pos++;
346     }
347   }
348   if (pos != 0) {
349     int aligned_pos = AlignByN(pos, 4);
350     struct_desc += "};";
351     const_data_.resize(aligned_pos * 4);
352     for (auto& it : float_values_) {
353       if (it.second.active) {
354         float* ptr =
355             reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
356         *ptr = it.second.value;
357       }
358     }
359     for (auto& it : int_values_) {
360       if (it.second.active) {
361         int32_t* ptr =
362             reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
363         *ptr = it.second.value;
364       }
365     }
366   } else {
367     struct_desc = "";
368   }
369   return struct_desc;
370 }
371 
SetInt(const std::string & name,int value)372 absl::Status MetalArguments::SetInt(const std::string& name, int value) {
373   auto it = int_values_.find(name);
374   if (it == int_values_.end()) {
375     return absl::NotFoundError(
376         absl::StrCat("No int argument with name - ", name));
377   }
378   it->second.value = value;
379   if (it->second.active) {
380     int32_t* ptr =
381         reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
382     *ptr = value;
383   }
384   return absl::OkStatus();
385 }
SetFloat(const std::string & name,float value)386 absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
387   auto it = float_values_.find(name);
388   if (it == float_values_.end()) {
389     return absl::NotFoundError(
390         absl::StrCat("No float argument with name - ", name));
391   }
392   it->second.value = value;
393   if (it->second.active) {
394     float* ptr =
395         reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
396     *ptr = value;
397   }
398   return absl::OkStatus();
399 }
400 
SetHalf(const std::string & name,half value)401 absl::Status MetalArguments::SetHalf(const std::string& name, half value) {
402   auto it = float_values_.find(name);
403   if (it == float_values_.end()) {
404     return absl::NotFoundError(
405         absl::StrCat("No half argument with name - ", name));
406   }
407   it->second.value = value;
408   if (it->second.active) {
409     float* ptr =
410         reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
411     *ptr = value;
412   }
413   return absl::OkStatus();
414 }
415 
SetObjectRef(const std::string & name,const GPUObject & object)416 absl::Status MetalArguments::SetObjectRef(const std::string& name,
417                                           const GPUObject& object) {
418   auto it = object_refs_.find(name);
419   if (it == object_refs_.end()) {
420     return absl::NotFoundError(
421         absl::StrCat("No object ref with name - ", name));
422   }
423   GPUResourcesWithValue resources;
424   RETURN_IF_ERROR(object.GetGPUResources(it->second.get(), &resources));
425   return SetGPUResources(name, resources);
426 }
427 
Encode(id<MTLComputeCommandEncoder> encoder,int buffer_offset,int texture_offset) const428 void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder,
429                             int buffer_offset, int texture_offset) const {
430   for (auto& b : buffers_) {
431     [encoder setBuffer:b.second.handle offset:0 atIndex:buffer_offset];
432     buffer_offset++;
433   }
434   for (auto& image : images2d_) {
435     [encoder setTexture:image.second.handle atIndex:texture_offset];
436     texture_offset++;
437   }
438   for (auto& image : image2d_arrays_) {
439     [encoder setTexture:image.second.handle atIndex:texture_offset];
440     texture_offset++;
441   }
442   for (auto& image : images3d_) {
443     [encoder setTexture:image.second.handle atIndex:texture_offset];
444     texture_offset++;
445   }
446   for (auto& image : image_buffers_) {
447     [encoder setTexture:image.second.handle atIndex:texture_offset];
448     texture_offset++;
449   }
450 
451   if (!const_data_.empty()) {
452     [encoder setBytes:const_data_.data()
453                length:const_data_.size()
454               atIndex:buffer_offset];
455   }
456 }
457 
AllocateObjects(const Arguments & args,id<MTLDevice> device)458 absl::Status MetalArguments::AllocateObjects(const Arguments& args,
459                                           id<MTLDevice> device) {
460   objects_.resize(args.objects_.size());
461   int i = 0;
462   for (auto& t : args.objects_) {
463     RETURN_IF_ERROR(CreateMetalObject(device, t.second.get(), &objects_[i]));
464     i++;
465   }
466   return absl::OkStatus();
467 }
468 
AddObjectArgs(Arguments * args)469 absl::Status MetalArguments::AddObjectArgs(Arguments* args) {
470   for (auto& t : args->objects_) {
471     AddGPUResources(t.first, t.second->GetGPUResources(), args);
472   }
473   for (auto& t : args->object_refs_) {
474     AddGPUResources(t.first, t.second->GetGPUResources(), args);
475   }
476   return absl::OkStatus();
477 }
478 
GetListOfArgs(int buffer_offset,int textures_offset)479 std::string MetalArguments::GetListOfArgs(int buffer_offset,
480                                           int textures_offset) {
481   std::string result;
482   for (auto& t : buffers_) {
483     AppendArgument(
484         absl::StrCat(MemoryTypeToMetalType(t.second.desc.memory_type), " ",
485                      ToMetalDataType(t.second.desc.data_type,
486                                      t.second.desc.element_size),
487                      "* ", t.first, "[[buffer(", buffer_offset, ")]]"),
488         &result);
489     buffer_offset++;
490   }
491   for (auto& t : images2d_) {
492     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
493     std::string data_type = ToMetalDataType(t.second.desc.data_type);
494     AppendArgument(absl::StrCat("texture2d<", data_type, ", ", access, "> ",
495                                 t.first, "[[texture(", textures_offset, ")]]"),
496                    &result);
497     textures_offset++;
498   }
499   for (auto& t : image2d_arrays_) {
500     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
501     std::string data_type = ToMetalDataType(t.second.desc.data_type);
502     AppendArgument(
503         absl::StrCat("texture2d_array<", data_type, ", ", access, "> ", t.first,
504                      "[[texture(", textures_offset, ")]]"),
505         &result);
506     textures_offset++;
507   }
508   for (auto& t : images3d_) {
509     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
510     std::string data_type = ToMetalDataType(t.second.desc.data_type);
511     AppendArgument(absl::StrCat("texture3d<", data_type, ", ", access, "> ",
512                                 t.first, "[[texture(", textures_offset, ")]]"),
513                    &result);
514     textures_offset++;
515   }
516   for (auto& t : image_buffers_) {
517     std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
518     std::string data_type = ToMetalDataType(t.second.desc.data_type);
519     AppendArgument(
520         absl::StrCat("texture_buffer<", data_type, ", ", access, "> ", t.first,
521                      "[[texture(", textures_offset, ")]]"),
522         &result);
523     textures_offset++;
524   }
525   if (!const_data_.empty()) {
526     AppendArgument(absl::StrCat("constant uniforms_buffer& U[[buffer(",
527                                 buffer_offset, ")]]"),
528                    &result);
529     buffer_offset++;
530   }
531   return result;
532 }
533 
SetGPUResources(const std::string & name,const GPUResourcesWithValue & resources)534 absl::Status MetalArguments::SetGPUResources(
535     const std::string& name, const GPUResourcesWithValue& resources) {
536   for (const auto& r : resources.ints) {
537     RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
538   }
539   for (const auto& r : resources.floats) {
540     RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
541   }
542   for (const auto& r : resources.buffers) {
543     RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
544   }
545   for (const auto& r : resources.images2d) {
546     RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
547   }
548   for (const auto& r : resources.image2d_arrays) {
549     RETURN_IF_ERROR(
550         SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
551   }
552   for (const auto& r : resources.images3d) {
553     RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
554   }
555   for (const auto& r : resources.image_buffers) {
556     RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
557   }
558   return absl::OkStatus();
559 }
560 
AddBuffer(const std::string & name,const GPUBufferDescriptor & desc)561 void MetalArguments::AddBuffer(const std::string& name,
562                                const GPUBufferDescriptor& desc) {
563   buffers_[name].desc = desc;
564 }
565 
AddImage2D(const std::string & name,const GPUImage2DDescriptor & desc)566 void MetalArguments::AddImage2D(const std::string& name,
567                                 const GPUImage2DDescriptor& desc) {
568   images2d_[name].desc = desc;
569 }
570 
AddImage2DArray(const std::string & name,const GPUImage2DArrayDescriptor & desc)571 void MetalArguments::AddImage2DArray(const std::string& name,
572                                      const GPUImage2DArrayDescriptor& desc) {
573   image2d_arrays_[name].desc = desc;
574 }
575 
AddImage3D(const std::string & name,const GPUImage3DDescriptor & desc)576 void MetalArguments::AddImage3D(const std::string& name,
577                                 const GPUImage3DDescriptor& desc) {
578   images3d_[name].desc = desc;
579 }
580 
AddImageBuffer(const std::string & name,const GPUImageBufferDescriptor & desc)581 void MetalArguments::AddImageBuffer(const std::string& name,
582                                     const GPUImageBufferDescriptor& desc) {
583   image_buffers_[name].desc = desc;
584 }
585 
AddGPUResources(const std::string & name,const GPUResources & resources,Arguments * args)586 void MetalArguments::AddGPUResources(const std::string& name,
587                                   const GPUResources& resources,
588                                   Arguments* args) {
589   for (const auto& r : resources.ints) {
590     args->AddInt(absl::StrCat(name, "_", r));
591   }
592   for (const auto& r : resources.floats) {
593     args->AddFloat(absl::StrCat(name, "_", r));
594   }
595   for (const auto& r : resources.buffers) {
596     AddBuffer(absl::StrCat(name, "_", r.first), r.second);
597   }
598   for (const auto& r : resources.images2d) {
599     AddImage2D(absl::StrCat(name, "_", r.first), r.second);
600   }
601   for (const auto& r : resources.image2d_arrays) {
602     AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
603   }
604   for (const auto& r : resources.images3d) {
605     AddImage3D(absl::StrCat(name, "_", r.first), r.second);
606   }
607   for (const auto& r : resources.image_buffers) {
608     AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
609   }
610 }
611 
SetBuffer(const std::string & name,id<MTLBuffer> handle)612 absl::Status MetalArguments::SetBuffer(const std::string& name,
613                                        id<MTLBuffer> handle) {
614   auto it = buffers_.find(name);
615   if (it == buffers_.end()) {
616     return absl::NotFoundError(
617         absl::StrCat("No buffer argument with name - ", name));
618   }
619   it->second.handle = handle;
620   return absl::OkStatus();
621 }
622 
SetImage2D(const std::string & name,id<MTLTexture> handle)623 absl::Status MetalArguments::SetImage2D(const std::string& name,
624                                         id<MTLTexture> handle) {
625   auto it = images2d_.find(name);
626   if (it == images2d_.end()) {
627     return absl::NotFoundError(
628         absl::StrCat("No image2d argument with name - ", name));
629   }
630   it->second.handle = handle;
631   return absl::OkStatus();
632 }
633 
SetImage2DArray(const std::string & name,id<MTLTexture> handle)634 absl::Status MetalArguments::SetImage2DArray(const std::string& name,
635                                              id<MTLTexture> handle) {
636   auto it = image2d_arrays_.find(name);
637   if (it == image2d_arrays_.end()) {
638     return absl::NotFoundError(
639         absl::StrCat("No image2d array argument with name - ", name));
640   }
641   it->second.handle = handle;
642   return absl::OkStatus();
643 }
644 
SetImage3D(const std::string & name,id<MTLTexture> handle)645 absl::Status MetalArguments::SetImage3D(const std::string& name,
646                                         id<MTLTexture> handle) {
647   auto it = images3d_.find(name);
648   if (it == images3d_.end()) {
649     return absl::NotFoundError(
650         absl::StrCat("No image3d argument with name - ", name));
651   }
652   it->second.handle = handle;
653   return absl::OkStatus();
654 }
655 
SetImageBuffer(const std::string & name,id<MTLTexture> handle)656 absl::Status MetalArguments::SetImageBuffer(const std::string& name,
657                                             id<MTLTexture> handle) {
658   auto it = image_buffers_.find(name);
659   if (it == image_buffers_.end()) {
660     return absl::NotFoundError(
661         absl::StrCat("No image buffer argument with name - ", name));
662   }
663   it->second.handle = handle;
664   return absl::OkStatus();
665 }
666 
ResolveSelectorsPass(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,std::string * code)667 absl::Status MetalArguments::ResolveSelectorsPass(
668     const GpuInfo& gpu_info, const Arguments& args,
669     const std::map<std::string, std::string>& linkables, std::string* code) {
670   std::string result;
671   size_t position = 0;
672   size_t next_position = code->find(kArgsPrefix);
673   while (next_position != std::string::npos) {
674     size_t arg_pos = next_position;
675     next_position += strlen(kArgsPrefix);
676     std::string object_name = GetNextWord(*code, next_position);
677     char next = (*code)[next_position + object_name.size()];
678     if (next == '.') {
679       next_position += object_name.size() + 1;
680       std::string selector_name = GetNextWord(*code, next_position);
681       next_position += selector_name.size();
682       next = (*code)[next_position];
683       std::vector<std::string> template_args;
684       if (next == '<') {
685         size_t close_bracket_pos;
686         RETURN_IF_ERROR(ParseArgsInsideBrackets(
687             *code, next_position, &close_bracket_pos, &template_args));
688         next_position = close_bracket_pos;
689         next = (*code)[next_position];
690       }
691       if (next != '(') {
692         return absl::NotFoundError(absl::StrCat(
693             "Expected ( after ", object_name, ".", selector_name, " call"));
694       }
695       std::vector<std::string> function_args;
696       size_t close_bracket_pos;
697       RETURN_IF_ERROR(ParseArgsInsideBrackets(
698           *code, next_position, &close_bracket_pos, &function_args));
699       for (auto& arg : function_args) {
700         RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
701       }
702       std::string patch;
703       RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
704                                       selector_name, function_args,
705                                       template_args, &patch));
706       code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
707       position = arg_pos + patch.size();
708     } else {
709       position = arg_pos + strlen(kArgsPrefix);
710     }
711     next_position = code->find(kArgsPrefix, position);
712   }
713   return absl::OkStatus();
714 }
715 
// Produces the Metal code for one selector call
// `args.<object_name>.<selector><template_args>(function_args)` and appends it
// to `*result`.
//
// Lookup: the object is searched in args.object_refs_ first, then in
// args.objects_. NotFoundError if it is in neither.
//
// Linking: for a TensorDescriptor with selector "Write" or "Linking" whose
// name has an entry in `linkables`:
//   - The object must have WRITE or READ_WRITE access, otherwise
//     FailedPreconditionError.
//   - The linkable snippet overwrites *result, with in_out_value, X_COORD,
//     Y_COORD and S_COORD replaced by values taken from the write's arguments.
//     Nested selectors inside the snippet are resolved without linkables.
//   - "Linking" stops there; "Write" continues and appends the regular write.
//
// The descriptor's PerformSelector output then has its member names
// qualified as `args.<object_name>_<member>` and is appended to *result.
absl::Status MetalArguments::ResolveSelector(
    const GpuInfo& gpu_info, const Arguments& args,
    const std::map<std::string, std::string>& linkables,
    const std::string& object_name, const std::string& selector,
    const std::vector<std::string>& function_args,
    const std::vector<std::string>& template_args, std::string* result) {
  const GPUObjectDescriptor* desc_ptr;
  auto it_ref = args.object_refs_.find(object_name);
  auto it_obj = args.objects_.find(object_name);
  if (it_ref != args.object_refs_.end()) {
    desc_ptr = it_ref->second.get();
  } else if (it_obj != args.objects_.end()) {
    desc_ptr = it_obj->second.get();
  } else {
    return absl::NotFoundError(
        absl::StrCat("No object with name - ", object_name));
  }
  auto names = desc_ptr->GetGPUResources().GetNames();
  const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
  if (tensor_desc && (selector == "Write" || selector == "Linking")) {
    auto it = linkables.find(object_name);
    if (it != linkables.end()) {
      if (desc_ptr->GetAccess() != AccessType::WRITE &&
          desc_ptr->GetAccess() != AccessType::READ_WRITE) {
        return absl::FailedPreconditionError(absl::StrCat(
            "Object with name - ", object_name, " should have Write access."));
      }
      std::string value_name, x_coord, y_coord, s_coord;
      RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
          function_args, &value_name, &x_coord, &y_coord, &s_coord));
      // x_coord can have batch size property of link_object
      ResolveObjectNames(object_name, names, &x_coord);
      // Overwrites (not appends to) *result with the linkable snippet.
      *result = it->second;
      ReplaceAllWords("in_out_value", value_name, result);
      ReplaceAllWords("X_COORD", x_coord, result);
      ReplaceAllWords("Y_COORD", y_coord, result);
      ReplaceAllWords("S_COORD", s_coord, result);
      RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
      if (selector == "Linking") {
        return absl::OkStatus();
      }
    }
  }
  std::string patch;
  RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
                                            template_args, &patch));
  ResolveObjectNames(object_name, names, &patch);
  *result += patch;
  return absl::OkStatus();
}
766 
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code)767 void MetalArguments::ResolveObjectNames(
768     const std::string& object_name,
769     const std::vector<std::string>& member_names, std::string* code) {
770   for (const auto& member_name : member_names) {
771     const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
772     ReplaceAllWords(member_name, new_name, code);
773   }
774 }
775 
ResolveArgsPass(std::string * code)776 void MetalArguments::ResolveArgsPass(std::string* code) {
777   size_t position = 0;
778   size_t next_position = code->find(kArgsPrefix);
779   while (next_position != std::string::npos) {
780     size_t arg_pos = next_position;
781     next_position += strlen(kArgsPrefix);
782     std::string object_name = GetNextWord(*code, next_position);
783     std::string new_name = object_name;
784     code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
785     position = arg_pos + new_name.size();
786     next_position = code->find(kArgsPrefix, position);
787   }
788 }
789 
SetObjectsResources(const Arguments & args)790 absl::Status MetalArguments::SetObjectsResources(const Arguments& args) {
791   int i = 0;
792   for (const auto& t : args.objects_) {
793     GPUResourcesWithValue resources;
794     RETURN_IF_ERROR(objects_[i]->GetGPUResources(t.second.get(), &resources));
795     RETURN_IF_ERROR(SetGPUResources(t.first, resources));
796     i++;
797   }
798   return absl::OkStatus();
799 }
800 
801 }  // namespace metal
802 }  // namespace gpu
803 }  // namespace tflite
804