/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal.h"

#include <cmath>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"

namespace tflite {
namespace gpu {

namespace {

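// Returns how many output slices (groups of 4 output channels) should be
// processed together, based on the destination channel count.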
int GetNumOutputSlices(int dst_channels) {
  const int dst_depth = DivideRoundUp(dst_channels, 4);
  if (dst_depth % 4 == 0 || dst_depth >= 16) {
    return 4;
  } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
    return 2;
  } else {
    return 1;
  }
}

struct GlobalIdsParams {
  std::vector<std::string> global_ids;
  std::vector<std::string> group_ids;
  std::vector<std::string> local_sizes;
  std::vector<std::string> local_ids;
  int3 block_size;
  int3 launch_order;
  bool linear_wh;
  bool linear_whs;
  std::string task_size_w;   // must be filled if linear_wh or linear_whs enabled
  std::string task_size_wh;  // must be filled if linear_whs enabled
};

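// Emits Metal code that computes the X, Y and Z coordinates of the output
// block handled by the current thread, taking into account the block size,
// the work-group launch order and the optional linearization of the
// width/height (and slice) dimensions.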
std::string GlobalIdsGen(const GlobalIdsParams& params) {
  std::string c;
  int3 launch_remap;
  launch_remap[params.launch_order.x] = 0;
  launch_remap[params.launch_order.y] = 1;
  launch_remap[params.launch_order.z] = 2;
  if (params.linear_whs) {
    c += " int linear_whs = " + params.global_ids[0] + ";\n";
    c += " int Z = (linear_whs / " + params.task_size_wh + ") * " +
         std::to_string(params.block_size.z) + ";\n";
    c += " int linear_wh = linear_whs % " + params.task_size_wh + ";\n";
    c += " int Y = (linear_wh / " + params.task_size_w + ") * " +
         std::to_string(params.block_size.y) + ";\n";
    c += " int X = (linear_wh % " + params.task_size_w + ") * " +
         std::to_string(params.block_size.x) + ";\n";
  } else if (params.linear_wh) {
    if (params.launch_order.x == 0) {
      c += " int linear_wh = " + params.global_ids[0] + ";\n";
    } else {
      c += " int linear_wh = " + params.group_ids[launch_remap.x] + " * " +
           params.local_sizes[0] + " + " + params.local_ids[0] + ";\n";
    }
    c += " int Y = (linear_wh / " + params.task_size_w + ") * " +
         std::to_string(params.block_size.y) + ";\n";
    c += " int X = (linear_wh % " + params.task_size_w + ") * " +
         std::to_string(params.block_size.x) + ";\n";
    if (params.launch_order.y == 1) {
      c += " int Z = " + params.global_ids[1] + " * " +
           std::to_string(params.block_size.z) + ";\n";
    } else {
      c += " int Z = (" + params.group_ids[launch_remap.y] + " * " +
           params.local_sizes[1] + " + " + params.local_ids[1] + ") * " +
           std::to_string(params.block_size.z) + ";\n";
    }
  } else {
    if (params.launch_order.x == 0) {
      c += " int X = " + params.global_ids[0] + " * " +
           std::to_string(params.block_size.x) + ";\n";
    } else {
      c += " int X = (" + params.group_ids[launch_remap.x] + " * " +
           params.local_sizes[0] + " + " + params.local_ids[0] + ") * " +
           std::to_string(params.block_size.x) + ";\n";
    }
    if (params.launch_order.y == 1) {
      c += " int Y = " + params.global_ids[1] + " * " +
           std::to_string(params.block_size.y) + ";\n";
    } else {
      c += " int Y = (" + params.group_ids[launch_remap.y] + " * " +
           params.local_sizes[1] + " + " + params.local_ids[1] + ") * " +
           std::to_string(params.block_size.y) + ";\n";
    }
    if (params.launch_order.z == 2) {
      c += " int Z = " + params.global_ids[2] + " * " +
           std::to_string(params.block_size.z) + ";\n";
    } else {
      c += " int Z = (" + params.group_ids[launch_remap.z] + " * " +
           params.local_sizes[2] + " + " + params.local_ids[2] + ") * " +
           std::to_string(params.block_size.z) + ";\n";
    }
  }
  return c;
}

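// Emits Metal code that cooperatively copies 'elements_to_upload' FLT4 values
// from global memory to threadgroup memory, distributing the copies evenly
// across 'total_work_items' threads.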
std::string GenerateUploadByThreads(const std::string& local_ptr_name,
                                    const std::string& global_ptr_name,
                                    const std::string& global_offset_name,
                                    const std::string& lid_name,
                                    int total_work_items,
                                    int elements_to_upload) {
  std::string c;
  std::string offset =
      global_offset_name.empty() ? "" : global_offset_name + " + ";
  const int groups = elements_to_upload / total_work_items;
  const int remainder = elements_to_upload % total_work_items;
  for (int i = 0; i < groups; ++i) {
    c += " " + local_ptr_name + "[" + lid_name + " + " +
         std::to_string(total_work_items * i) + "] = " + global_ptr_name + "[" +
         offset + lid_name + " + " + std::to_string(total_work_items * i) +
         "];\n";
  }
  if (remainder != 0) {
    c += " if (" + lid_name + " < " + std::to_string(remainder) + ") {\n";
    c += " " + local_ptr_name + "[" + lid_name + " + " +
         std::to_string(total_work_items * groups) + "] = " + global_ptr_name +
         "[" + offset + lid_name + " + " +
         std::to_string(total_work_items * groups) + "];\n";
    c += " }\n";
  }
  return c;
}

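// Generates the full Metal compute kernel for the convolution described by
// 'params': per-thread accumulators for a block of output values, optional
// threadgroup or SIMD-broadcast weight uploads, the source/kernel loops and
// the final bias add and write-out.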
std::string GenerateConvolution(const ConvolutionMetal::ConvParams& params,
                                const OperationDef& definition,
                                bool stride_correction) {
  GlobalIdsParams ids_params;
  ids_params.group_ids = {"group_id.x", "group_id.y", "group_id.z"};
  ids_params.global_ids = {"ugid.x", "ugid.y", "ugid.z"};
  ids_params.local_ids = {"tid3d.x", "tid3d.y", "tid3d.z"};
  ids_params.local_sizes = {"lsize.x", "lsize.y", "lsize.z"};
  ids_params.linear_wh = params.linear_wh;
  ids_params.task_size_w = "args.task_size_x";
  ids_params.task_size_wh = "args.task_size_y";
  ids_params.linear_whs = params.linear_whs;
  ids_params.block_size = params.block_size;
  ids_params.launch_order = params.work_group_launch_order;

  std::string addr_space =
      params.weights_upload_type ==
              ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
          ? "constant"
          : "device";
  const bool use_local_mem =
      params.weights_upload_type ==
      ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
  const int local_mem_size =
      params.block_size.z * 4 * params.src_depth_loop_size;

  const bool use_simd_broadcast =
      params.weights_upload_type ==
          ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST ||
      params.weights_upload_type ==
          ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD16_BROADCAST ||
      params.weights_upload_type ==
          ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD32_BROADCAST;
  int simd_size = 1;
  if (params.weights_upload_type ==
      ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST) {
    simd_size = 8;
  } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
                                               PRIVATE_MEM_SIMD16_BROADCAST) {
    simd_size = 16;
  } else if (params.weights_upload_type == ConvolutionMetal::WeightsUploadType::
                                               PRIVATE_MEM_SIMD32_BROADCAST) {
    simd_size = 32;
  }

  const bool use_filters_constants =
      !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
      params.y_kernel_is_1;

  const auto src_storage_type = definition.src_tensors[0].storage_type;
  const auto dst_storage_type = definition.dst_tensors[0].storage_type;
  const bool src_is_linear =
      src_storage_type == TensorStorageType::BUFFER ||
      src_storage_type == TensorStorageType::IMAGE_BUFFER;
  const bool dst_is_linear =
      dst_storage_type == TensorStorageType::BUFFER ||
      dst_storage_type == TensorStorageType::IMAGE_BUFFER;

  std::string channels[4] = {"x", "y", "z", "w"};
  std::string c;
  c.reserve(16 * 1024);  // Reserve large enough buffer.
  c += R"(
kernel void ComputeFunction(
    $0
    uint tid[[thread_index_in_threadgroup]],
    uint3 group_id[[threadgroup_position_in_grid]],
    uint3 tid3d[[thread_position_in_threadgroup]],
    uint3 lsize[[threads_per_threadgroup]],
)";
  if (use_simd_broadcast) {
    c += " uint simd_id[[thread_index_in_simdgroup]],\n";
  }
  c += " uint3 ugid[[thread_position_in_grid]]){\n";
  c += GlobalIdsGen(ids_params);
  c += " if (Z >= args.dst_tensor.Slices()) return;\n";
  bool late_xy_check = use_local_mem || use_simd_broadcast;
  if (!late_xy_check && !params.linear_whs) {
    c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
         "return;\n";
  }
  for (int z = 0; z < params.block_size.z; ++z) {
    for (int y = 0; y < params.block_size.y; ++y) {
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_i =
            std::to_string(z) + std::to_string(y) + std::to_string(x);
        c +=
            " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
      }
    }
  }
  auto for_every_yx =
      [&](std::function<std::string(const std::string&, const std::string&,
                                    const std::string&, int, int)>
              lambda) {
        for (int y = 0; y < params.block_size.y; ++y) {
          const std::string s_y = std::to_string(y);
          for (int x = 0; x < params.block_size.x; ++x) {
            const std::string s_x = std::to_string(x);
            const std::string s_yx = s_y + s_x;
            c += lambda(s_yx, s_x, s_y, x, y) + "\n";
          }
        }
      };
  if (!use_filters_constants) {
    std::string kern_x = params.x_kernel_is_1 ? "" : " * args.kernel_size_x";
    std::string kern_y = params.y_kernel_is_1 ? "" : " * args.kernel_size_y";
    std::string dst_offset =
        params.need_dst_loop ? " + Z * 4 * args.src_tensor.Slices()" : "";
    if (!params.need_dst_loop) {
      c += " " + addr_space + " FLT4* tmp = args.weights.GetPtr();\n";
    } else {
      if (params.different_weights_for_height) {
        c += " " + addr_space +
             " FLT4* tmp = args.weights.GetPtr() + (Z * "
             "args.src_tensor.Height() + Y * " +
             std::to_string(params.block_size.z) +
             ") * 4 * args.src_tensor.Slices();\n";
      } else {
        c += " " + addr_space +
             " FLT4* tmp = args.weights.GetPtr() + Z * 4 * "
             "args.src_tensor.Slices()" +
             kern_x + kern_y + ";\n";
      }
    }
  }
  if (!params.x_kernel_is_1) {
    for (int x = 0; x < params.block_size.x; ++x) {
      const std::string s_x = std::to_string(x);
      if (stride_correction) {
        c += " int x" + s_x + " = " +
             GetXStrideCorrected("(X + " + s_x + ")", "args.src_tensor.Batch()",
                                 "args.stride_x", "args.padding_x") +
             ";\n";
      } else {
        c += " int x" + s_x + " = (X + " + s_x +
             ") * args.stride_x + args.padding_x;\n";
      }
    }
  }
  if (!params.y_kernel_is_1) {
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      c += " int y" + s_y + " = (Y + " + s_y +
           ") * args.stride_y + args.padding_y;\n";
    }
  }
  if (use_local_mem) {
    c += " threadgroup FLT4 weights_cache[" + std::to_string(local_mem_size) +
         "];\n";
  }
  if (!params.y_kernel_is_1) {
    c += " int y = 0;\n";
    c += " do {\n";
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      c += " int c_y" + s_y + " = y * args.dilation_y + y" + s_y + ";\n";
      if (src_is_linear) {
        c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
             " >= args.src_tensor.Height();\n";
        c += " c_y" + s_y + " = clamp(c_y" + s_y +
             ", 0, args.src_tensor.Height() - 1);\n";
      }
    }
  } else {
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      c += " int c_y" + s_y + " = clamp(Y + " + s_y +
           ", 0, args.src_tensor.Height() - 1);\n";
    }
  }
  if (!params.x_kernel_is_1) {
    c += " int x = 0;\n";
    c += " do {\n";
    for (int x = 0; x < params.block_size.x; ++x) {
      const std::string s_x = std::to_string(x);
      c += " int c_x" + s_x + " = x * args.dilation_x + x" + s_x + ";\n";
      if (src_is_linear) {
        c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
             " >= args.src_tensor.Width();\n";
        c += " c_x" + s_x + " = clamp(c_x" + s_x +
             ", 0, args.src_tensor.Width() - 1);\n";
      }
    }
  } else {
    for (int x = 0; x < params.block_size.x; ++x) {
      const std::string s_x = std::to_string(x);
      c += " int c_x" + s_x + " = clamp(X + " + s_x +
           ", 0, args.src_tensor.Width() - 1);\n";
    }
  }
  if (src_is_linear) {
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_x = std::to_string(x);
        const std::string s_yx = s_y + s_x;
        if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
          c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x +
               "_out);\n";
        } else if (!params.y_kernel_is_1) {
          c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
        } else if (!params.x_kernel_is_1) {
          c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
        }
      }
    }
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_x = std::to_string(x);
        const std::string s_yx = s_y + s_x;
        if (definition.src_tensors[0].storage_type ==
            TensorStorageType::BUFFER) {
          c += " device FLT4* src_loc_" + s_yx +
               " = args.src_tensor.GetHandle() + "
               "args.src_tensor.GetWHOffset(c_x" +
               s_x + ", c_y" + s_y + ");\n";
        } else if (definition.src_tensors[0].storage_type ==
                   TensorStorageType::IMAGE_BUFFER) {
          c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
               s_x + ", c_y" + s_y + ");\n";
        }
      }
    }
  }
  c += " int s = 0;\n";
  if (params.need_src_loop) {
    c += " do {\n";
  }
  if (use_local_mem) {
    const int total_work_items = params.work_group_size.x *
                                 params.work_group_size.y *
                                 params.work_group_size.z;
    c += " SIMDGROUP_BARRIER(mem_flags::mem_none);\n";
    c += GenerateUploadByThreads("weights_cache", "tmp",
                                 /*global_offset_name*/ "", "tid",
                                 total_work_items, local_mem_size);
    c += " SIMDGROUP_BARRIER(mem_flags::mem_threadgroup);\n";
  } else if (use_simd_broadcast) {
    int parts = local_mem_size / simd_size;
    int remainder = local_mem_size % simd_size;
    for (int i = 0; i < parts; ++i) {
      c += " FLT4 simd_w" + std::to_string(i) + " = tmp[simd_id + " +
           std::to_string(i * simd_size) + "];\n";
    }
    if (remainder) {
      c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
      c += " if (simd_id < " + std::to_string(remainder) + ") {\n";
      c += " simd_w" + std::to_string(parts) + " = tmp[simd_id + " +
           std::to_string(parts * simd_size) + "];\n";
      c += " }\n";
    }
  }
  auto declare_src = [&]() {
    for (int y = 0; y < params.block_size.y; ++y) {
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_yx = std::to_string(y) + std::to_string(x);
        c += " FLT4 src" + s_yx + ";\n";
      }
    }
  };
  auto read_src = [&]() {
    for (int y = 0; y < params.block_size.y; ++y) {
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_yx = std::to_string(y) + std::to_string(x);
        if (src_is_linear) {
          if (definition.src_tensors[0].storage_type ==
              TensorStorageType::BUFFER) {
            if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
              c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
                   ";\n";
            } else {
              c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
            }
          } else if (definition.src_tensors[0].storage_type ==
                     TensorStorageType::IMAGE_BUFFER) {
            if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
              c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
                   s_yx + ") * m" + s_yx + ";\n";
            } else {
              c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
                   s_yx + ");\n";
            }
          }
        } else {
          c += " src" + s_yx + " = args.src_tensor.Read(c_x" +
               std::to_string(x) + ", c_y" + std::to_string(y) + ", s);\n";
        }
      }
    }
    if (src_is_linear) {
      for (int y = 0; y < params.block_size.y; ++y) {
        for (int x = 0; x < params.block_size.x; ++x) {
          const std::string s_yx = std::to_string(y) + std::to_string(x);
          c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
        }
      }
    }
  };
  auto conv_core = [&](int offset) {
    std::string name = use_local_mem ? "weights_cache" : "tmp";
    if (use_filters_constants) {
      name = "args.weights.GetPtr()";
    }
    for (int z = 0; z < params.block_size.z; ++z) {
      for (int ch = 0; ch < 4; ++ch) {
        for (int y = 0; y < params.block_size.y; ++y) {
          for (int x = 0; x < params.block_size.x; ++x) {
            std::string s_id = std::to_string(y) + std::to_string(x);
            std::string r_id =
                std::to_string(z) + std::to_string(y) + std::to_string(x);
            std::string f_val =
                name + "[" + std::to_string(z * 4 + ch + offset) + "]";
            if (use_simd_broadcast) {
              int simd_id = (z * 4 + ch + offset) / simd_size;
              int thread_id = (z * 4 + ch + offset) % simd_size;
              f_val = "simd_broadcast(simd_w" + std::to_string(simd_id) + ", " +
                      std::to_string(thread_id) + "u)";
            }
            std::string s_val = "src" + s_id;
            std::string r_val = "r" + r_id;
            if (params.weights_layout == WeightsLayout::kOHWIOGroupO4I4) {
              c += " " + r_val + "." + channels[ch] + " += dot(" + f_val +
                   ", " + s_val + ");\n";
            } else {  // WeightsInnerBlockLayout::I4O4
              std::string temp_sum = f_val + " * " + s_val + "." + channels[ch];
              if (definition.precision == CalculationsPrecision::F32_F16) {
                temp_sum = "float4(" + temp_sum + ")";
              }
              c += " " + r_val + " += " + temp_sum + ";\n";
            }
          }
        }
      }
    }
  };
  declare_src();
  read_src();
  c += " s += 1;\n";
  conv_core(0);
  for (int i = 1; i < params.src_depth_loop_size; ++i) {
    read_src();
    conv_core(i * params.block_size.z * 4);
    c += " s += 1;\n";
  }
  if (!use_filters_constants) {
    c += " tmp += " +
         std::to_string(params.block_size.z * 4 * params.src_depth_loop_size) +
         ";\n";
  }
  if (params.need_src_loop) {
    c += " } while (s < args.src_tensor.Slices());\n";
  }
  if (!params.x_kernel_is_1) {
    c += " x++;\n";
    c += " } while (x < args.kernel_size_x);\n";
  }
  if (!params.y_kernel_is_1) {
    c += " y++;\n";
    c += " } while (y < args.kernel_size_y);\n";
  }

  if (late_xy_check && !params.linear_whs) {
    c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
         "return;\n";
  }

  if (dst_is_linear) {
    for_every_yx([](const std::string& s_yx, const std::string& s_x,
                    const std::string& s_y, int x, int y) {
      return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
             ", Y + " + s_y + ", Z);";
    });
  }

  std::string bias_name = "args.biases.GetPtr()";
  if (params.need_dst_loop) {
    c += " device FLT4* bias_loc = args.biases.GetPtr() + Z;\n";
    bias_name = "bias_loc";
  }
  for (int y = 0; y < params.block_size.y; ++y) {
    for (int x = 0; x < params.block_size.x; ++x) {
      for (int z = 0; z < params.block_size.z; ++z) {
        std::string r_id =
            std::to_string(z) + std::to_string(y) + std::to_string(x);
        c += " r" + r_id + " += TO_ACCUM_TYPE(" + bias_name + "[" +
             std::to_string(z) + "]);\n";
      }
    }
  }
  for (int z = 0; z < params.block_size.z; ++z) {
    const std::string s_z = std::to_string(z);
    c += " if (Z + " + s_z + " < args.dst_tensor.Slices()) {\n";
    for (int y = 0; y < params.block_size.y; ++y) {
      const std::string s_y = std::to_string(y);
      for (int x = 0; x < params.block_size.x; ++x) {
        const std::string s_x = std::to_string(x);
        const std::string s_yx = s_y + s_x;
        const std::string s_zyx = s_z + s_yx;
        bool need_check_x = x >= 1;
        bool need_check_y = y >= 1;
        std::string check;
        if (need_check_x) {
          check += "(X + " + s_x + ") < args.dst_tensor.Width()";
        }
        if (need_check_y) {
          check += check.empty() ? "" : " && ";
          check += "(Y + " + s_y + ") < args.dst_tensor.Height()";
        }
        if (!check.empty()) {
          c += " if (" + check + ") {\n";
        } else {
          c += " {\n";
        }
        c += " FLT4 value = FLT4(r" + s_zyx + ");\n";
        if (dst_is_linear) {
          c += " int linear_index = offset_" + s_yx +
               " + args.dst_tensor.SliceStride() * " + s_z + ";\n";
          c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
               s_y + ", Z + " + s_z + ");\n";
          c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
        } else {
          c += " args.dst_tensor.Write(value, X + " + s_x + ", Y + " +
               s_y + ", Z + " + s_z + ");\n";
        }
        c += " }\n";
      }
    }
    c += " }\n";
  }
  c += "}\n";
  return c;
}

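// Converts the OHWI float weights into the byte layout expected by the
// generated kernel for the given weights description and data type.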
std::vector<uint8_t> ReorderWeightsForConv(
    const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
    const WeightsDescription& weights_desc, const DataType& weights_type) {
  const int flt_count =
      GetTotalElementsCountForLayout(weights_desc, weights.shape);
  std::vector<uint8_t> result(flt_count * SizeOf(weights_type));
  RearrangeWeights(weights, weights_desc, weights_type, absl::MakeSpan(result));
  return result;
}

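// Packs the bias values into a buffer of 'output_size' elements of the given
// type, zero-padding past the end of the original bias vector.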
std::vector<uint8_t> ReorderBiasesForConv(
    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
    const DataType& biases_type, int output_size) {
  std::vector<uint8_t> result(output_size * SizeOf(biases_type));
  if (biases_type == DataType::FLOAT32) {
    float* gpu_data = reinterpret_cast<float*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  } else {
    half* gpu_data = reinterpret_cast<half*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  }
  return result;
}

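// The helpers below estimate how many threadgroups a given work-group size and
// block size would launch for 'dst_shape': with a 3D grid, with width and
// height linearized, or with width, height and slices all linearized.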
int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size,
                   const int3& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.z);

  return DivideRoundUp(grid_x, wg_size.x) * DivideRoundUp(grid_y, wg_size.y) *
         DivideRoundUp(grid_z, wg_size.z);
}

int GetGroupsCountForLinearWH(const BHWC& dst_shape, const int3& wg_size,
                              const int3& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.z);

  return DivideRoundUp(grid_x * grid_y, wg_size.x) *
         DivideRoundUp(grid_z, wg_size.y);
}

int GetGroupsCountForLinearWHS(const BHWC& dst_shape, const int3& wg_size,
                               const int3& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x);
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.z);

  return DivideRoundUp(grid_x * grid_y * grid_z, wg_size.x);
}

bool IsKernelXIs1(const Convolution2DAttributes& attr) {
  return attr.weights.shape.w == 1 && attr.strides.w == 1 &&
         attr.dilations.w == 1 && attr.padding.prepended.w == 0 &&
         attr.padding.appended.w == 0;
}

bool IsKernelYIs1(const Convolution2DAttributes& attr) {
  return attr.weights.shape.h == 1 && attr.strides.h == 1 &&
         attr.dilations.h == 1 && attr.padding.prepended.h == 0 &&
         attr.padding.appended.h == 0;
}

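// Estimates the available parallelism for 'dst_shape' and, based on how many
// waves it would produce per compute unit, recommends the total number of
// output values (block volume) each thread should compute.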
int GetMaximumPossibleWavesCount(const AppleInfo& apple_info,
                                 const BHWC& dst_shape) {
  if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
    return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
  } else {
    return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1});
  }
}

int GetRecommendedBlockSize(const AppleInfo& apple_info,
                            const BHWC& dst_shape) {
  const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
  const int cu_count = apple_info.GetComputeUnitsCount();
  if (max_waves >= cu_count * 64) {
    return 8;
  } else if (max_waves >= cu_count * 32) {
    return 4;
  } else if (max_waves >= cu_count * 16) {
    return 2;
  } else {
    return 1;
  }
}

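// Tuning for A7/A8-class Apple GPUs: starts from threadgroup-memory weight
// uploads and derives block and work-group sizes from the recommended block
// volume, optionally linearizing the grid when that reduces the group count.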
ConvolutionMetal::ConvParams GetConvParamsForA7A8(
    const AppleInfo& apple_info, const Convolution2DAttributes& attr,
    const BHWC& dst_shape) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);

  ConvolutionMetal::ConvParams params;
  params.weights_upload_type =
      ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
  params.x_kernel_is_1 = IsKernelXIs1(attr);
  params.y_kernel_is_1 = IsKernelYIs1(attr);
  params.src_depth_loop_size = 1;
  params.block_size = int3(1, 1, 1);
  params.linear_wh = false;
  params.linear_whs = false;
  params.work_group_launch_order = int3(0, 1, 2);
  params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;

  int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);

  if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
    params.block_size.z = 4;
    blk_total_size /= 4;
  } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
    params.block_size.z = 2;
    blk_total_size /= 2;
  }
  if (blk_total_size >= 4) {
    params.block_size.x = 2;
    params.block_size.y = 2;
    blk_total_size /= 4;
  } else if (blk_total_size >= 2) {
    if (dst_shape.w % 2 != 0 && dst_shape.h % 2 == 0) {
      params.block_size.y = 2;
    } else {
      params.block_size.x = 2;
    }
    blk_total_size /= 2;
  }

  params.work_group_size = params.block_size.x <= params.block_size.y
                               ? int3(8, 4, 1)
                               : int3(4, 8, 1);

  int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size);
  int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, params.block_size);
  int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, params.block_size);

  if (g2 < g1) {
    params.linear_wh = true;
    params.work_group_size = int3(32, 1, 1);
    params.work_group_launch_order = int3(0, 1, 2);
  }
  float precise_threshold = 3.1f;
  float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
  if (precise_ratio > precise_threshold) {
    params.linear_wh = false;
    params.linear_whs = true;
    params.work_group_size = int3(32, 1, 1);
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
  }

  if (params.src_depth_loop_size == src_slices) {
    params.need_src_loop = false;
  }
  if (params.block_size.z == dst_slices) {
    params.need_dst_loop = false;
  }
  const bool use_filters_constants =
      !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
      params.y_kernel_is_1;
  if (use_filters_constants) {
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
  }

  return params;
}

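// Tuning for A9 and newer Apple GPUs: defaults to global-memory weights and
// additionally unrolls the source-depth loop when the output block is small.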
ConvolutionMetal::ConvParams GetConvParamsForA9AndHigher(
    const AppleInfo& apple_info, const Convolution2DAttributes& attr,
    const BHWC& dst_shape) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
  int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
  int3 block_size = int3(1, 1, 1);
  if (blk_total_size >= 2 && apple_info.IsBionic()) {
    if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
      block_size.x = 2;
    } else {
      block_size.y = 2;
    }
    blk_total_size /= 2;
  }
  if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
    block_size.z = 4;
    blk_total_size /= 4;
  } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
    block_size.z = 2;
    blk_total_size /= 2;
  }
  if (blk_total_size >= 4 && dst_slices == 3) {
    block_size.z = 3;
    blk_total_size /= 4;
  }

  ConvolutionMetal::ConvParams params;
  params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
  params.x_kernel_is_1 = IsKernelXIs1(attr);
  params.y_kernel_is_1 = IsKernelYIs1(attr);
  params.src_depth_loop_size = 1;
  params.block_size = block_size;
  params.linear_wh = false;
  params.linear_whs = false;
  params.work_group_size = int3(8, 4, 1);
  params.work_group_launch_order = int3(2, 0, 1);
  params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
  int g1 = GetGroupsCount(dst_shape, {8, 4, 1}, block_size);
  int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, block_size);
  int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, block_size);
  if (g2 < g1) {
    params.linear_wh = true;
    params.work_group_size = int3(32, 1, 1);
    params.work_group_launch_order = int3(0, 1, 2);
  }
  float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
  float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
  if (precise_ratio > precise_threshold) {
    params.linear_wh = false;
    params.linear_whs = true;
    params.work_group_size = int3(32, 1, 1);
  }
  int total_elements =
      params.block_size.x * params.block_size.y * params.block_size.z;
  if (total_elements == 1) {
    if (src_slices % 4 == 0) {
      params.src_depth_loop_size = 4;
    } else if (src_slices % 2 == 0) {
      params.src_depth_loop_size = 2;
    }
  } else if (total_elements == 2) {
    if (src_slices % 2 == 0) {
      params.src_depth_loop_size = 2;
    }
  }
  if (params.src_depth_loop_size == src_slices) {
    params.need_src_loop = false;
  }
  if (params.block_size.z == dst_slices) {
    params.need_dst_loop = false;
  }
  const bool use_filters_constants =
      !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
      params.y_kernel_is_1;
  if (use_filters_constants) {
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::CONSTANT_MEM;
  }

  return params;
}

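// Tuning for Intel GPUs: broadcasts weights across an 8-wide SIMD group and
// picks the weights layout based on the calculation precision.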
ConvolutionMetal::ConvParams GetConvParamsForIntel(
    const Convolution2DAttributes& attr, CalculationsPrecision precision,
    const BHWC& dst_shape) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
  ConvolutionMetal::ConvParams params;
  params.weights_upload_type =
      ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
  params.x_kernel_is_1 = IsKernelXIs1(attr);
  params.y_kernel_is_1 = IsKernelYIs1(attr);
  params.src_depth_loop_size = 1;
  params.linear_wh = false;
  params.linear_whs = false;
  params.work_group_launch_order = int3(2, 0, 1);
  params.block_size = int3(1, 1, 1);
  if (dst_slices % 4 == 0 || dst_slices >= 8) {
    params.block_size.z = 4;
  } else if (dst_slices % 2 == 0 || dst_slices >= 4) {
    params.block_size.z = 2;
  }
  params.work_group_size = int3(8, 2, 1);
  if (precision == CalculationsPrecision::F32_F16) {
    params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
  } else {
    params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
  }

  if (src_slices % 2 == 0) {
    params.src_depth_loop_size = 2;
  }

  int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size);
  int g2 = GetGroupsCountForLinearWH(dst_shape, {16, 1, 1}, params.block_size);

  if (g2 < g1) {
    params.linear_wh = true;
    params.work_group_size = int3(16, 1, 1);
    params.work_group_launch_order = int3(1, 0, 2);
  }

  return params;
}

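// Tuning for AMD GPUs: a fixed 1x1x4 block with global-memory weights and a
// precision-dependent weights layout.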
ConvolutionMetal::ConvParams GetConvParamsForAMD(
    const Convolution2DAttributes& attr, CalculationsPrecision precision,
    const BHWC& dst_shape) {
  ConvolutionMetal::ConvParams params;
  params.block_size = int3(1, 1, 4);
  params.work_group_size = int3(8, 4, 1);
  params.work_group_launch_order = int3(2, 0, 1);
  params.src_depth_loop_size = 1;
  params.need_src_loop = true;
  params.need_dst_loop = true;
  params.linear_wh = false;
  params.linear_whs = false;
  params.weights_upload_type = ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
  params.different_weights_for_height = false;
  params.x_kernel_is_1 = IsKernelXIs1(attr);
  params.y_kernel_is_1 = IsKernelYIs1(attr);
  if (precision == CalculationsPrecision::F32_F16) {
    params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
  } else {
    params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
  }
  return params;
}

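// Dispatches to the vendor-specific tuning above; unknown GPUs get the same
// conservative defaults as AMD with an O4I4 weights layout.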
ConvolutionMetal::ConvParams GetConvParams(const GpuInfo& gpu_info,
                                           const Convolution2DAttributes& attr,
                                           CalculationsPrecision precision,
                                           const BHWC& dst_shape) {
  if (gpu_info.IsApple()) {
    if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
      return GetConvParamsForA7A8(gpu_info.apple_info, attr, dst_shape);
    } else {
      return GetConvParamsForA9AndHigher(gpu_info.apple_info, attr, dst_shape);
    }
  } else if (gpu_info.IsIntel()) {
    return GetConvParamsForIntel(attr, precision, dst_shape);
  } else if (gpu_info.IsAMD()) {
    return GetConvParamsForAMD(attr, precision, dst_shape);
  } else {
    ConvolutionMetal::ConvParams params;
    params.block_size = int3(1, 1, 4);
    params.work_group_size = int3(8, 4, 1);
    params.work_group_launch_order = int3(2, 0, 1);
    params.src_depth_loop_size = 1;
    params.need_src_loop = true;
    params.need_dst_loop = true;
    params.linear_wh = false;
    params.linear_whs = false;
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
    params.different_weights_for_height = false;
    params.x_kernel_is_1 = IsKernelXIs1(attr);
    params.y_kernel_is_1 = IsKernelYIs1(attr);
    params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
    return params;
  }
}

}  // namespace

absl::Status ConvolutionMetal::BindArguments(ArgumentsBinder* args) {
  RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch()));
  RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
  const int grid_x =
      DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
  const int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
  RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x));
  RETURN_IF_ERROR(args->SetInt("task_size_y", grid_x * grid_y));
  return absl::OkStatus();
}

int3 ConvolutionMetal::GetGridSize() const {
  int grid_x =
      DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), params_.block_size.x);
  int grid_y = DivideRoundUp(dst_[0]->Height(), params_.block_size.y);
  int grid_z = DivideRoundUp(dst_[0]->Slices(), params_.block_size.z);

  if (params_.linear_whs) {
    return int3(grid_x * grid_y * grid_z, 1, 1);
  } else if (params_.linear_wh) {
    return int3(grid_x * grid_y, grid_z, 1);
  } else {
    return int3(grid_x, grid_y, grid_z);
  }
}

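// Builds a ConvolutionMetal operation for the given shapes and attributes:
// folds the batch dimension into width for tuning, selects ConvParams,
// generates the kernel source and packs weights and biases into GPU buffers.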
ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
                                        const BHWC& dst_shape,
                                        const Convolution2DAttributes& attr,
                                        const GpuInfo& gpu_info) {
  BHWC new_shape = BHWC(1, dst_shape.h, dst_shape.w * dst_shape.b, dst_shape.c);
  ConvolutionMetal::ConvParams params =
      GetConvParams(gpu_info, attr, definition.precision, new_shape);

  ConvolutionMetal desc(definition);
  desc.params_ = params;
  const bool stride_correction =
      definition.IsBatchSupported() && attr.strides.w != 1;
  desc.code_ = GenerateConvolution(params, definition, stride_correction);

  auto src_desc = definition.src_tensors[0];
  if (definition.IsBatchSupported()) {
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  desc.AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = definition.dst_tensors[0];
  if (definition.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  desc.AddDstTensor("dst_tensor", dst_desc);

  desc.args_.AddInt("kernel_size_x", attr.weights.shape.w);
  desc.args_.AddInt("kernel_size_y", attr.weights.shape.h);
  desc.args_.AddInt("dilation_x", attr.dilations.w);
  desc.args_.AddInt("dilation_y", attr.dilations.h);
  desc.args_.AddInt("stride_x", attr.strides.w);
  desc.args_.AddInt("stride_y", attr.strides.h);
  desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
  desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
  desc.padding_ = int2(-attr.padding.prepended.w, -attr.padding.prepended.h);
  desc.dilation_ = int2(attr.dilations.w, attr.dilations.h);

  auto weights_type = DeduceDataTypeFromPrecision(definition.precision);

  MemoryType mem_type =
      params.weights_upload_type ==
              ConvolutionMetal::WeightsUploadType::CONSTANT_MEM
          ? MemoryType::CONSTANT
          : MemoryType::GLOBAL;

  if (definition.src_tensors.size() == 2) {
    // dynamic weights
    BufferDescriptor weights_desc;
    weights_desc.element_type = definition.src_tensors[1].data_type;
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    desc.AddSrcBuffer("weights", weights_desc);
  } else {
    BufferDescriptor weights_desc;
    weights_desc.element_type = weights_type;
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    weights_desc.data = ReorderWeightsForConv(
        attr.weights, desc.GetWeightsDescription(), weights_type);
    weights_desc.size = weights_desc.data.size();
    desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
                                        std::move(weights_desc)));
  }

  BufferDescriptor bias_desc;
  bias_desc.element_type = weights_type;
  bias_desc.element_size = 4;
  bias_desc.memory_type = mem_type;
  bias_desc.data = ReorderBiasesForConv(
      attr.bias, weights_type,
      AlignByN(attr.weights.shape.o, params.block_size.z * 4));
  bias_desc.size = bias_desc.data.size();
  desc.args_.AddObject(
      "biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));

  desc.args_.AddInt("task_size_x");
  desc.args_.AddInt("task_size_y");

  desc.work_group_size_ = params.work_group_size;
  desc.work_group_launch_order_ = params.work_group_launch_order;
  if (params.linear_whs) {
    desc.grid_dimension_ = 1;
  } else if (params.linear_wh) {
    desc.grid_dimension_ = 2;
  } else {
    desc.grid_dimension_ = 3;
  }

  return desc;
}

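// Variant used after a Winograd 4x4-to-6x6 transform: the convolution becomes
// a 1x1 convolution whose weights differ per output row (one set per Winograd
// sub-filter), with zero biases.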
ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
    const OperationDef& definition, const BHWC& dst_shape,
    const Convolution2DAttributes& attr, const GpuInfo& gpu_info) {
  ConvolutionMetal::ConvParams params;
  params.work_group_launch_order = int3(2, 0, 1);
  params.src_depth_loop_size = 1;
  params.need_src_loop = true;
  params.need_dst_loop = true;
  params.linear_wh = false;
  params.linear_whs = false;
  params.different_weights_for_height = true;
  params.x_kernel_is_1 = true;
  params.y_kernel_is_1 = true;
  if (gpu_info.IsApple()) {
    params.weights_layout = WeightsLayout::kOHWIOGroupO4I4;
    if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
      params.weights_upload_type =
          ConvolutionMetal::WeightsUploadType::LOCAL_MEM_BY_THREADS;
      params.work_group_size = int3(32, 1, 1);
      params.block_size = int3(4, 1, 4);
    } else {
      params.weights_upload_type =
          ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
      params.work_group_size = int3(8, 4, 1);
      params.block_size = int3(4, 1, 4);
    }
  } else if (gpu_info.IsIntel()) {
    params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::PRIVATE_MEM_SIMD8_BROADCAST;
    params.work_group_size = int3(16, 1, 1);
    params.block_size = int3(1, 1, 4);
  } else if (gpu_info.IsAMD()) {
    params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
    params.work_group_size = int3(32, 1, 1);
    params.block_size = int3(2, 1, 4);
  } else {
    params.weights_layout = WeightsLayout::kOHWIOGroupI4O4;
    params.weights_upload_type =
        ConvolutionMetal::WeightsUploadType::GLOBAL_MEM;
    params.work_group_size = int3(32, 1, 1);
    params.block_size = int3(2, 1, 4);
  }

  ConvolutionMetal desc(definition);
  desc.params_ = params;
  desc.code_ = GenerateConvolution(params, definition, false);
  auto src_desc = definition.src_tensors[0];
  if (definition.IsBatchSupported()) {
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  desc.AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = definition.dst_tensors[0];
  if (definition.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  desc.AddDstTensor("dst_tensor", dst_desc);

  desc.args_.AddInt("kernel_size_x", 1);
  desc.args_.AddInt("kernel_size_y", 1);
  desc.args_.AddInt("dilation_x", 1);
  desc.args_.AddInt("dilation_y", 1);
  desc.args_.AddInt("stride_x", 1);
  desc.args_.AddInt("stride_y", 1);
  desc.args_.AddInt("padding_x", 0);
  desc.args_.AddInt("padding_y", 0);
  desc.padding_ = int2(0, 0);
  desc.dilation_ = int2(1, 1);

  auto weights_type = DeduceDataTypeFromPrecision(definition.precision);

  tflite::gpu::Tensor<OHWI, DataType::FLOAT32> wino_weights;
  tflite::gpu::Tensor<Linear, DataType::FLOAT32> wino_biases;
  RearrangeWeightsToWinograd4x4To6x6Weights(attr.weights, &wino_weights);
  wino_biases.shape = Linear(attr.weights.shape.o);
  wino_biases.data.resize(attr.weights.shape.o, 0.0f);

  BufferDescriptor weights_desc;
  weights_desc.element_type = weights_type;
  weights_desc.element_size = 4;
  weights_desc.data = ReorderWeightsForConv(
      wino_weights, desc.GetWeightsDescription(), weights_type);
  weights_desc.size = weights_desc.data.size();
  desc.args_.AddObject(
      "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));

  BufferDescriptor bias_desc;
  bias_desc.element_type = weights_type;
  bias_desc.element_size = 4;
  bias_desc.data = ReorderBiasesForConv(
      wino_biases, weights_type,
      AlignByN(attr.weights.shape.o, params.block_size.z * 4));
  bias_desc.size = bias_desc.data.size();
  desc.args_.AddObject(
      "biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));

  desc.args_.AddInt("task_size_x");
  desc.args_.AddInt("task_size_y");

  desc.work_group_size_ = params.work_group_size;
  desc.work_group_launch_order_ = params.work_group_launch_order;
  if (params.linear_whs) {
    desc.grid_dimension_ = 1;
  } else if (params.linear_wh) {
    desc.grid_dimension_ = 2;
  } else {
    desc.grid_dimension_ = 3;
  }

  return desc;
}

bool IsConvolutionMetalSupported(const OperationDef& definition) {
  return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
}

}  // namespace gpu
}  // namespace tflite