1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"

#include <cstdint>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer.h"
#include "tensorflow/lite/delegates/gpu/metal/linear_storage.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
#include "tensorflow/lite/delegates/gpu/metal/texture2d.h"
26
27 namespace tflite {
28 namespace gpu {
29 namespace metal {
30 namespace {
IsWordSymbol(char symbol)31 bool IsWordSymbol(char symbol) {
32 return absl::ascii_isalnum(symbol) || symbol == '_';
33 }
34
ReplaceAllWords(const std::string & old_word,const std::string & new_word,std::string * str)35 void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
36 std::string* str) {
37 size_t position = str->find(old_word);
38 while (position != std::string::npos) {
39 char prev = position == 0 ? '.' : (*str)[position - 1];
40 char next = position + old_word.size() < str->size()
41 ? (*str)[position + old_word.size()]
42 : '.';
43 if (IsWordSymbol(prev) || IsWordSymbol(next)) {
44 position = str->find(old_word, position + 1);
45 continue;
46 }
47 str->replace(position, old_word.size(), new_word);
48 position = str->find(old_word, position + new_word.size());
49 }
50 }
51
GetNextWord(const std::string & code,size_t first_position)52 std::string GetNextWord(const std::string& code, size_t first_position) {
53 size_t pos = first_position;
54 char t = code[pos];
55 while (IsWordSymbol(t)) {
56 pos++;
57 t = code[pos];
58 }
59 return code.substr(first_position, pos - first_position);
60 }
61
// Finds the bracket that closes an already-opened `bracket`.
//
// `bracket` must be one of '(', '{', '[', '<'. Scanning starts at `first_pos`,
// which is normally just past the opening bracket, with one bracket counted as
// open. Nested brackets of the same kind are balanced. Other bracket kinds are
// ignored.
//
// Returns the index one past the matching closing bracket. Returns
// std::string::npos when `bracket` is not supported or the text ends before
// the bracket is closed.
size_t FindEnclosingBracket(const std::string& text, size_t first_pos,
                            char bracket) {
  char closing;
  switch (bracket) {
    case '(':
      closing = ')';
      break;
    case '{':
      closing = '}';
      break;
    case '[':
      closing = ']';
      break;
    case '<':
      closing = '>';
      break;
    default:
      return std::string::npos;
  }
  int depth = 1;
  for (size_t pos = first_pos; pos < text.size(); ++pos) {
    if (text[pos] == bracket) {
      ++depth;
    } else if (text[pos] == closing && --depth == 0) {
      return pos + 1;
    }
  }
  return std::string::npos;
}
93
ParseArgsInsideBrackets(const std::string & text,size_t open_bracket_pos,size_t * close_bracket_pos,std::vector<std::string> * args)94 absl::Status ParseArgsInsideBrackets(const std::string& text,
95 size_t open_bracket_pos,
96 size_t* close_bracket_pos,
97 std::vector<std::string>* args) {
98 *close_bracket_pos =
99 FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
100 if (*close_bracket_pos == -1) {
101 return absl::NotFoundError("Not found enclosing bracket");
102 }
103 std::string str_args = text.substr(open_bracket_pos + 1,
104 *close_bracket_pos - open_bracket_pos - 2);
105 std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
106 args->reserve(words.size());
107 for (const auto& word : words) {
108 absl::string_view arg = absl::StripAsciiWhitespace(word);
109 if (!arg.empty()) {
110 args->push_back(std::string(arg));
111 }
112 }
113 return absl::OkStatus();
114 }
115
// Appends `arg` to the argument list `*args`. Entries are separated by ",\n",
// and no separator is emitted before the first entry.
void AppendArgument(const std::string& arg, std::string* args) {
  if (!args->empty()) {
    *args += ",\n";
  }
  *args += arg;
}
122
CreateMetalObject(id<MTLDevice> device,GPUObjectDescriptor * desc,GPUObjectPtr * result)123 absl::Status CreateMetalObject(id<MTLDevice> device, GPUObjectDescriptor* desc,
124 GPUObjectPtr* result) {
125 const auto* buffer_desc = dynamic_cast<const BufferDescriptor*>(desc);
126 if (buffer_desc) {
127 Buffer gpu_buffer;
128 RETURN_IF_ERROR(
129 gpu_buffer.CreateFromBufferDescriptor(*buffer_desc, device));
130 *result = absl::make_unique<Buffer>(std::move(gpu_buffer));
131 return absl::OkStatus();
132 }
133
134 const auto* texture_desc = dynamic_cast<const Texture2DDescriptor*>(desc);
135 if (texture_desc) {
136 Texture2D gpu_texture;
137 RETURN_IF_ERROR(
138 gpu_texture.CreateFromTexture2DDescriptor(*texture_desc, device));
139 *result = absl::make_unique<Texture2D>(std::move(gpu_texture));
140 return absl::OkStatus();
141 }
142
143 const auto* linear_desc = dynamic_cast<const TensorLinearDescriptor*>(desc);
144 if (linear_desc) {
145 LinearStorage gpu_storage;
146 RETURN_IF_ERROR(
147 gpu_storage.CreateFromTensorLinearDescriptor(*linear_desc, device));
148 *result = absl::make_unique<LinearStorage>(std::move(gpu_storage));
149 return absl::OkStatus();
150 }
151
152 const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc);
153 if (tensor_desc) {
154 MetalSpatialTensor gpu_tensor;
155 RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*tensor_desc, device));
156 *result = absl::make_unique<MetalSpatialTensor>(std::move(gpu_tensor));
157 return absl::OkStatus();
158 }
159
160 return absl::InvalidArgumentError("Unknown GPU descriptor.");
161 }
162
AccessToMetalTextureAccess(AccessType access_type)163 std::string AccessToMetalTextureAccess(AccessType access_type) {
164 if (access_type == AccessType::READ) {
165 return "access::read";
166 } else if (access_type == AccessType::READ_WRITE) {
167 return "access::read_write";
168 } else if (access_type == AccessType::WRITE) {
169 return "access::write";
170 } else {
171 return "access::unknown";
172 }
173 }
174 } // namespace
175
176 // Static
177 constexpr char MetalArguments::kArgsPrefix[];
178
Init(const std::map<std::string,std::string> & linkables,MetalDevice * device,Arguments * args,std::string * code)179 absl::Status MetalArguments::Init(
180 const std::map<std::string, std::string>& linkables, MetalDevice* device,
181 Arguments* args, std::string* code) {
182 RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
183 RETURN_IF_ERROR(AddObjectArgs(args));
184 RETURN_IF_ERROR(
185 ResolveSelectorsPass(device->GetInfo(), *args, linkables, code));
186 object_refs_ = std::move(args->object_refs_);
187 args->GetActiveArguments(kArgsPrefix, *code);
188 std::string struct_desc = ScalarArgumentsToStructWithVec4Fields(args, code);
189 RETURN_IF_ERROR(SetObjectsResources(*args));
190 ResolveArgsPass(code);
191 std::string header = R"(
192 #include <metal_stdlib>
193 using namespace metal;
194
195 )";
196 header += struct_desc + "\n";
197 *code = header + *code;
198 std::string arguments = GetListOfArgs(/*buffer_offset*/ 0);
199 const bool use_global_id = code->find("GLOBAL_ID_") != std::string::npos;
200 const bool use_local_id = code->find("LOCAL_ID_") != std::string::npos;
201 const bool use_group_id = code->find("GROUP_ID_") != std::string::npos;
202 const bool use_group_size = code->find("GROUP_SIZE_") != std::string::npos;
203 const bool use_simd_id =
204 code->find("SUB_GROUP_LOCAL_ID") != std::string::npos;
205 if (use_global_id) {
206 AppendArgument("uint3 reserved_gid[[thread_position_in_grid]]", &arguments);
207 }
208 if (use_local_id) {
209 AppendArgument("uint3 reserved_lid[[thread_position_in_threadgroup]]",
210 &arguments);
211 }
212 if (use_group_id) {
213 AppendArgument("uint3 reserved_group_id[[threadgroup_position_in_grid]]",
214 &arguments);
215 }
216 if (use_group_size) {
217 AppendArgument("uint3 reserved_group_size[[threads_per_threadgroup]]",
218 &arguments);
219 }
220 if (use_simd_id) {
221 AppendArgument("uint reserved_simd_id[[thread_index_in_simdgroup]]",
222 &arguments);
223 }
224 if (!use_global_id && !use_local_id && !use_group_id && !use_group_size &&
225 !arguments.empty()) {
226 arguments += ",\n";
227 }
228 *code = absl::Substitute(*code, arguments);
229 return absl::OkStatus();
230 }
231
ScalarArgumentsToStructWithScalarFields(Arguments * args,std::string * code)232 std::string MetalArguments::ScalarArgumentsToStructWithScalarFields(
233 Arguments* args, std::string* code) {
234 std::string struct_desc = "struct uniforms_buffer {\n";
235 int pos = 0;
236 for (auto& fvalue : args->float_values_) {
237 auto& new_val = float_values_[fvalue.first];
238 new_val.value = fvalue.second.value;
239 new_val.active = fvalue.second.active;
240 if (fvalue.second.active) {
241 new_val.bytes_offset = pos * 4;
242 pos++;
243 struct_desc += " float " + fvalue.first + ";\n";
244 ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code);
245 }
246 }
247 for (const auto& hfvalue : args->half_values_) {
248 auto& new_val = float_values_[hfvalue.first];
249 new_val.value = hfvalue.second.value;
250 new_val.active = hfvalue.second.active;
251 if (hfvalue.second.active) {
252 new_val.bytes_offset = pos * 4;
253 pos++;
254 struct_desc += " float " + hfvalue.first + ";\n";
255 ReplaceAllWords(kArgsPrefix + hfvalue.first,
256 "static_cast<half>(U." + hfvalue.first + ")", code);
257 }
258 }
259 for (auto& ivalue : args->int_values_) {
260 auto& new_val = int_values_[ivalue.first];
261 new_val.value = ivalue.second.value;
262 new_val.active = ivalue.second.active;
263 if (ivalue.second.active) {
264 new_val.bytes_offset = pos * 4;
265 pos++;
266 struct_desc += " int " + ivalue.first + ";\n";
267 ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code);
268 }
269 }
270 if (pos != 0) {
271 int aligned_pos = AlignByN(pos, 4);
272 for (int i = pos; i < aligned_pos; i++) {
273 struct_desc += " int dummy" + std::to_string(i - pos) + ";\n";
274 }
275 struct_desc += "};";
276 const_data_.resize(aligned_pos * 4);
277 for (auto& it : float_values_) {
278 if (it.second.active) {
279 float* ptr =
280 reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
281 *ptr = it.second.value;
282 }
283 }
284 for (auto& it : int_values_) {
285 if (it.second.active) {
286 int32_t* ptr =
287 reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
288 *ptr = it.second.value;
289 }
290 }
291 } else {
292 struct_desc = "";
293 }
294 return struct_desc;
295 }
296
ScalarArgumentsToStructWithVec4Fields(Arguments * args,std::string * code)297 std::string MetalArguments::ScalarArgumentsToStructWithVec4Fields(
298 Arguments* args, std::string* code) {
299 std::string struct_desc = "struct uniforms_buffer {\n";
300 int pos = 0;
301 std::string channels[4] = {".x", ".y", ".z", ".w"};
302 for (auto& fvalue : args->float_values_) {
303 auto& new_val = float_values_[fvalue.first];
304 new_val.value = fvalue.second.value;
305 new_val.active = fvalue.second.active;
306 if (fvalue.second.active) {
307 new_val.bytes_offset = pos * 4;
308 if (pos % 4 == 0) {
309 struct_desc += " float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
310 }
311 std::string new_name =
312 "U.cmp_float4_" + std::to_string(pos / 4) + channels[pos % 4];
313 ReplaceAllWords(kArgsPrefix + fvalue.first, new_name, code);
314 pos++;
315 }
316 }
317 for (const auto& hfvalue : args->half_values_) {
318 auto& new_val = float_values_[hfvalue.first];
319 new_val.value = hfvalue.second.value;
320 new_val.active = hfvalue.second.active;
321 if (hfvalue.second.active) {
322 new_val.bytes_offset = pos * 4;
323 if (pos % 4 == 0) {
324 struct_desc += " float4 cmp_float4_" + std::to_string(pos / 4) + ";\n";
325 }
326 std::string new_name = "static_cast<half>(U.cmp_float4_" +
327 std::to_string(pos / 4) + channels[pos % 4] + ")";
328 ReplaceAllWords(kArgsPrefix + hfvalue.first, new_name, code);
329 pos++;
330 }
331 }
332 pos = AlignByN(pos, 4);
333 for (auto& ivalue : args->int_values_) {
334 auto& new_val = int_values_[ivalue.first];
335 new_val.value = ivalue.second.value;
336 new_val.active = ivalue.second.active;
337 if (ivalue.second.active) {
338 new_val.bytes_offset = pos * 4;
339 if (pos % 4 == 0) {
340 struct_desc += " int4 cmp_int4_" + std::to_string(pos / 4) + ";\n";
341 }
342 std::string new_name =
343 "U.cmp_int4_" + std::to_string(pos / 4) + channels[pos % 4];
344 ReplaceAllWords(kArgsPrefix + ivalue.first, new_name, code);
345 pos++;
346 }
347 }
348 if (pos != 0) {
349 int aligned_pos = AlignByN(pos, 4);
350 struct_desc += "};";
351 const_data_.resize(aligned_pos * 4);
352 for (auto& it : float_values_) {
353 if (it.second.active) {
354 float* ptr =
355 reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
356 *ptr = it.second.value;
357 }
358 }
359 for (auto& it : int_values_) {
360 if (it.second.active) {
361 int32_t* ptr =
362 reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
363 *ptr = it.second.value;
364 }
365 }
366 } else {
367 struct_desc = "";
368 }
369 return struct_desc;
370 }
371
SetInt(const std::string & name,int value)372 absl::Status MetalArguments::SetInt(const std::string& name, int value) {
373 auto it = int_values_.find(name);
374 if (it == int_values_.end()) {
375 return absl::NotFoundError(
376 absl::StrCat("No int argument with name - ", name));
377 }
378 it->second.value = value;
379 if (it->second.active) {
380 int32_t* ptr =
381 reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
382 *ptr = value;
383 }
384 return absl::OkStatus();
385 }
SetFloat(const std::string & name,float value)386 absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
387 auto it = float_values_.find(name);
388 if (it == float_values_.end()) {
389 return absl::NotFoundError(
390 absl::StrCat("No float argument with name - ", name));
391 }
392 it->second.value = value;
393 if (it->second.active) {
394 float* ptr =
395 reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
396 *ptr = value;
397 }
398 return absl::OkStatus();
399 }
400
SetHalf(const std::string & name,half value)401 absl::Status MetalArguments::SetHalf(const std::string& name, half value) {
402 auto it = float_values_.find(name);
403 if (it == float_values_.end()) {
404 return absl::NotFoundError(
405 absl::StrCat("No half argument with name - ", name));
406 }
407 it->second.value = value;
408 if (it->second.active) {
409 float* ptr =
410 reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
411 *ptr = value;
412 }
413 return absl::OkStatus();
414 }
415
SetObjectRef(const std::string & name,const GPUObject & object)416 absl::Status MetalArguments::SetObjectRef(const std::string& name,
417 const GPUObject& object) {
418 auto it = object_refs_.find(name);
419 if (it == object_refs_.end()) {
420 return absl::NotFoundError(
421 absl::StrCat("No object ref with name - ", name));
422 }
423 GPUResourcesWithValue resources;
424 RETURN_IF_ERROR(object.GetGPUResources(it->second.get(), &resources));
425 return SetGPUResources(name, resources);
426 }
427
Encode(id<MTLComputeCommandEncoder> encoder,int buffer_offset,int texture_offset) const428 void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder,
429 int buffer_offset, int texture_offset) const {
430 for (auto& b : buffers_) {
431 [encoder setBuffer:b.second.handle offset:0 atIndex:buffer_offset];
432 buffer_offset++;
433 }
434 for (auto& image : images2d_) {
435 [encoder setTexture:image.second.handle atIndex:texture_offset];
436 texture_offset++;
437 }
438 for (auto& image : image2d_arrays_) {
439 [encoder setTexture:image.second.handle atIndex:texture_offset];
440 texture_offset++;
441 }
442 for (auto& image : images3d_) {
443 [encoder setTexture:image.second.handle atIndex:texture_offset];
444 texture_offset++;
445 }
446 for (auto& image : image_buffers_) {
447 [encoder setTexture:image.second.handle atIndex:texture_offset];
448 texture_offset++;
449 }
450
451 if (!const_data_.empty()) {
452 [encoder setBytes:const_data_.data()
453 length:const_data_.size()
454 atIndex:buffer_offset];
455 }
456 }
457
AllocateObjects(const Arguments & args,id<MTLDevice> device)458 absl::Status MetalArguments::AllocateObjects(const Arguments& args,
459 id<MTLDevice> device) {
460 objects_.resize(args.objects_.size());
461 int i = 0;
462 for (auto& t : args.objects_) {
463 RETURN_IF_ERROR(CreateMetalObject(device, t.second.get(), &objects_[i]));
464 i++;
465 }
466 return absl::OkStatus();
467 }
468
AddObjectArgs(Arguments * args)469 absl::Status MetalArguments::AddObjectArgs(Arguments* args) {
470 for (auto& t : args->objects_) {
471 AddGPUResources(t.first, t.second->GetGPUResources(), args);
472 }
473 for (auto& t : args->object_refs_) {
474 AddGPUResources(t.first, t.second->GetGPUResources(), args);
475 }
476 return absl::OkStatus();
477 }
478
GetListOfArgs(int buffer_offset,int textures_offset)479 std::string MetalArguments::GetListOfArgs(int buffer_offset,
480 int textures_offset) {
481 std::string result;
482 for (auto& t : buffers_) {
483 AppendArgument(
484 absl::StrCat(MemoryTypeToMetalType(t.second.desc.memory_type), " ",
485 ToMetalDataType(t.second.desc.data_type,
486 t.second.desc.element_size),
487 "* ", t.first, "[[buffer(", buffer_offset, ")]]"),
488 &result);
489 buffer_offset++;
490 }
491 for (auto& t : images2d_) {
492 std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
493 std::string data_type = ToMetalDataType(t.second.desc.data_type);
494 AppendArgument(absl::StrCat("texture2d<", data_type, ", ", access, "> ",
495 t.first, "[[texture(", textures_offset, ")]]"),
496 &result);
497 textures_offset++;
498 }
499 for (auto& t : image2d_arrays_) {
500 std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
501 std::string data_type = ToMetalDataType(t.second.desc.data_type);
502 AppendArgument(
503 absl::StrCat("texture2d_array<", data_type, ", ", access, "> ", t.first,
504 "[[texture(", textures_offset, ")]]"),
505 &result);
506 textures_offset++;
507 }
508 for (auto& t : images3d_) {
509 std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
510 std::string data_type = ToMetalDataType(t.second.desc.data_type);
511 AppendArgument(absl::StrCat("texture3d<", data_type, ", ", access, "> ",
512 t.first, "[[texture(", textures_offset, ")]]"),
513 &result);
514 textures_offset++;
515 }
516 for (auto& t : image_buffers_) {
517 std::string access = AccessToMetalTextureAccess(t.second.desc.access_type);
518 std::string data_type = ToMetalDataType(t.second.desc.data_type);
519 AppendArgument(
520 absl::StrCat("texture_buffer<", data_type, ", ", access, "> ", t.first,
521 "[[texture(", textures_offset, ")]]"),
522 &result);
523 textures_offset++;
524 }
525 if (!const_data_.empty()) {
526 AppendArgument(absl::StrCat("constant uniforms_buffer& U[[buffer(",
527 buffer_offset, ")]]"),
528 &result);
529 buffer_offset++;
530 }
531 return result;
532 }
533
SetGPUResources(const std::string & name,const GPUResourcesWithValue & resources)534 absl::Status MetalArguments::SetGPUResources(
535 const std::string& name, const GPUResourcesWithValue& resources) {
536 for (const auto& r : resources.ints) {
537 RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
538 }
539 for (const auto& r : resources.floats) {
540 RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
541 }
542 for (const auto& r : resources.buffers) {
543 RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
544 }
545 for (const auto& r : resources.images2d) {
546 RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
547 }
548 for (const auto& r : resources.image2d_arrays) {
549 RETURN_IF_ERROR(
550 SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
551 }
552 for (const auto& r : resources.images3d) {
553 RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
554 }
555 for (const auto& r : resources.image_buffers) {
556 RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
557 }
558 return absl::OkStatus();
559 }
560
AddBuffer(const std::string & name,const GPUBufferDescriptor & desc)561 void MetalArguments::AddBuffer(const std::string& name,
562 const GPUBufferDescriptor& desc) {
563 buffers_[name].desc = desc;
564 }
565
AddImage2D(const std::string & name,const GPUImage2DDescriptor & desc)566 void MetalArguments::AddImage2D(const std::string& name,
567 const GPUImage2DDescriptor& desc) {
568 images2d_[name].desc = desc;
569 }
570
AddImage2DArray(const std::string & name,const GPUImage2DArrayDescriptor & desc)571 void MetalArguments::AddImage2DArray(const std::string& name,
572 const GPUImage2DArrayDescriptor& desc) {
573 image2d_arrays_[name].desc = desc;
574 }
575
AddImage3D(const std::string & name,const GPUImage3DDescriptor & desc)576 void MetalArguments::AddImage3D(const std::string& name,
577 const GPUImage3DDescriptor& desc) {
578 images3d_[name].desc = desc;
579 }
580
AddImageBuffer(const std::string & name,const GPUImageBufferDescriptor & desc)581 void MetalArguments::AddImageBuffer(const std::string& name,
582 const GPUImageBufferDescriptor& desc) {
583 image_buffers_[name].desc = desc;
584 }
585
AddGPUResources(const std::string & name,const GPUResources & resources,Arguments * args)586 void MetalArguments::AddGPUResources(const std::string& name,
587 const GPUResources& resources,
588 Arguments* args) {
589 for (const auto& r : resources.ints) {
590 args->AddInt(absl::StrCat(name, "_", r));
591 }
592 for (const auto& r : resources.floats) {
593 args->AddFloat(absl::StrCat(name, "_", r));
594 }
595 for (const auto& r : resources.buffers) {
596 AddBuffer(absl::StrCat(name, "_", r.first), r.second);
597 }
598 for (const auto& r : resources.images2d) {
599 AddImage2D(absl::StrCat(name, "_", r.first), r.second);
600 }
601 for (const auto& r : resources.image2d_arrays) {
602 AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
603 }
604 for (const auto& r : resources.images3d) {
605 AddImage3D(absl::StrCat(name, "_", r.first), r.second);
606 }
607 for (const auto& r : resources.image_buffers) {
608 AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
609 }
610 }
611
SetBuffer(const std::string & name,id<MTLBuffer> handle)612 absl::Status MetalArguments::SetBuffer(const std::string& name,
613 id<MTLBuffer> handle) {
614 auto it = buffers_.find(name);
615 if (it == buffers_.end()) {
616 return absl::NotFoundError(
617 absl::StrCat("No buffer argument with name - ", name));
618 }
619 it->second.handle = handle;
620 return absl::OkStatus();
621 }
622
SetImage2D(const std::string & name,id<MTLTexture> handle)623 absl::Status MetalArguments::SetImage2D(const std::string& name,
624 id<MTLTexture> handle) {
625 auto it = images2d_.find(name);
626 if (it == images2d_.end()) {
627 return absl::NotFoundError(
628 absl::StrCat("No image2d argument with name - ", name));
629 }
630 it->second.handle = handle;
631 return absl::OkStatus();
632 }
633
SetImage2DArray(const std::string & name,id<MTLTexture> handle)634 absl::Status MetalArguments::SetImage2DArray(const std::string& name,
635 id<MTLTexture> handle) {
636 auto it = image2d_arrays_.find(name);
637 if (it == image2d_arrays_.end()) {
638 return absl::NotFoundError(
639 absl::StrCat("No image2d array argument with name - ", name));
640 }
641 it->second.handle = handle;
642 return absl::OkStatus();
643 }
644
SetImage3D(const std::string & name,id<MTLTexture> handle)645 absl::Status MetalArguments::SetImage3D(const std::string& name,
646 id<MTLTexture> handle) {
647 auto it = images3d_.find(name);
648 if (it == images3d_.end()) {
649 return absl::NotFoundError(
650 absl::StrCat("No image3d argument with name - ", name));
651 }
652 it->second.handle = handle;
653 return absl::OkStatus();
654 }
655
SetImageBuffer(const std::string & name,id<MTLTexture> handle)656 absl::Status MetalArguments::SetImageBuffer(const std::string& name,
657 id<MTLTexture> handle) {
658 auto it = image_buffers_.find(name);
659 if (it == image_buffers_.end()) {
660 return absl::NotFoundError(
661 absl::StrCat("No image buffer argument with name - ", name));
662 }
663 it->second.handle = handle;
664 return absl::OkStatus();
665 }
666
ResolveSelectorsPass(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,std::string * code)667 absl::Status MetalArguments::ResolveSelectorsPass(
668 const GpuInfo& gpu_info, const Arguments& args,
669 const std::map<std::string, std::string>& linkables, std::string* code) {
670 std::string result;
671 size_t position = 0;
672 size_t next_position = code->find(kArgsPrefix);
673 while (next_position != std::string::npos) {
674 size_t arg_pos = next_position;
675 next_position += strlen(kArgsPrefix);
676 std::string object_name = GetNextWord(*code, next_position);
677 char next = (*code)[next_position + object_name.size()];
678 if (next == '.') {
679 next_position += object_name.size() + 1;
680 std::string selector_name = GetNextWord(*code, next_position);
681 next_position += selector_name.size();
682 next = (*code)[next_position];
683 std::vector<std::string> template_args;
684 if (next == '<') {
685 size_t close_bracket_pos;
686 RETURN_IF_ERROR(ParseArgsInsideBrackets(
687 *code, next_position, &close_bracket_pos, &template_args));
688 next_position = close_bracket_pos;
689 next = (*code)[next_position];
690 }
691 if (next != '(') {
692 return absl::NotFoundError(absl::StrCat(
693 "Expected ( after ", object_name, ".", selector_name, " call"));
694 }
695 std::vector<std::string> function_args;
696 size_t close_bracket_pos;
697 RETURN_IF_ERROR(ParseArgsInsideBrackets(
698 *code, next_position, &close_bracket_pos, &function_args));
699 for (auto& arg : function_args) {
700 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, &arg));
701 }
702 std::string patch;
703 RETURN_IF_ERROR(ResolveSelector(gpu_info, args, linkables, object_name,
704 selector_name, function_args,
705 template_args, &patch));
706 code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
707 position = arg_pos + patch.size();
708 } else {
709 position = arg_pos + strlen(kArgsPrefix);
710 }
711 next_position = code->find(kArgsPrefix, position);
712 }
713 return absl::OkStatus();
714 }
715
ResolveSelector(const GpuInfo & gpu_info,const Arguments & args,const std::map<std::string,std::string> & linkables,const std::string & object_name,const std::string & selector,const std::vector<std::string> & function_args,const std::vector<std::string> & template_args,std::string * result)716 absl::Status MetalArguments::ResolveSelector(
717 const GpuInfo& gpu_info, const Arguments& args,
718 const std::map<std::string, std::string>& linkables,
719 const std::string& object_name, const std::string& selector,
720 const std::vector<std::string>& function_args,
721 const std::vector<std::string>& template_args, std::string* result) {
722 const GPUObjectDescriptor* desc_ptr;
723 auto it_ref = args.object_refs_.find(object_name);
724 auto it_obj = args.objects_.find(object_name);
725 if (it_ref != args.object_refs_.end()) {
726 desc_ptr = it_ref->second.get();
727 } else if (it_obj != args.objects_.end()) {
728 desc_ptr = it_obj->second.get();
729 } else {
730 return absl::NotFoundError(
731 absl::StrCat("No object with name - ", object_name));
732 }
733 auto names = desc_ptr->GetGPUResources().GetNames();
734 const auto* tensor_desc = dynamic_cast<const TensorDescriptor*>(desc_ptr);
735 if (tensor_desc && (selector == "Write" || selector == "Linking")) {
736 auto it = linkables.find(object_name);
737 if (it != linkables.end()) {
738 if (desc_ptr->GetAccess() != AccessType::WRITE &&
739 desc_ptr->GetAccess() != AccessType::READ_WRITE) {
740 return absl::FailedPreconditionError(absl::StrCat(
741 "Object with name - ", object_name, " should have Write access."));
742 }
743 std::string value_name, x_coord, y_coord, s_coord;
744 RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(
745 function_args, &value_name, &x_coord, &y_coord, &s_coord));
746 // x_coord can have batch size property of link_object
747 ResolveObjectNames(object_name, names, &x_coord);
748 *result = it->second;
749 ReplaceAllWords("in_out_value", value_name, result);
750 ReplaceAllWords("X_COORD", x_coord, result);
751 ReplaceAllWords("Y_COORD", y_coord, result);
752 ReplaceAllWords("S_COORD", s_coord, result);
753 RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, args, {}, result));
754 if (selector == "Linking") {
755 return absl::OkStatus();
756 }
757 }
758 }
759 std::string patch;
760 RETURN_IF_ERROR(desc_ptr->PerformSelector(gpu_info, selector, function_args,
761 template_args, &patch));
762 ResolveObjectNames(object_name, names, &patch);
763 *result += patch;
764 return absl::OkStatus();
765 }
766
ResolveObjectNames(const std::string & object_name,const std::vector<std::string> & member_names,std::string * code)767 void MetalArguments::ResolveObjectNames(
768 const std::string& object_name,
769 const std::vector<std::string>& member_names, std::string* code) {
770 for (const auto& member_name : member_names) {
771 const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
772 ReplaceAllWords(member_name, new_name, code);
773 }
774 }
775
ResolveArgsPass(std::string * code)776 void MetalArguments::ResolveArgsPass(std::string* code) {
777 size_t position = 0;
778 size_t next_position = code->find(kArgsPrefix);
779 while (next_position != std::string::npos) {
780 size_t arg_pos = next_position;
781 next_position += strlen(kArgsPrefix);
782 std::string object_name = GetNextWord(*code, next_position);
783 std::string new_name = object_name;
784 code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
785 position = arg_pos + new_name.size();
786 next_position = code->find(kArgsPrefix, position);
787 }
788 }
789
SetObjectsResources(const Arguments & args)790 absl::Status MetalArguments::SetObjectsResources(const Arguments& args) {
791 int i = 0;
792 for (const auto& t : args.objects_) {
793 GPUResourcesWithValue resources;
794 RETURN_IF_ERROR(objects_[i]->GetGPUResources(t.second.get(), &resources));
795 RETURN_IF_ERROR(SetGPUResources(t.first, resources));
796 i++;
797 }
798 return absl::OkStatus();
799 }
800
801 } // namespace metal
802 } // namespace gpu
803 } // namespace tflite
804