/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"

#include <map>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {
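// Returns the maximum total work-group size (x * y * z) used for the
// reduction kernel on this GPU: a power of two, reduced on Adreno 3xx and
// older Mali GPUs.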
int GetMaximumWGTotalSize(const GpuInfo& gpu_info) {
  // total_wg_size must be a power of 2 and >= 4.
  int total_wg_size = 256;
  if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
    total_wg_size = 128;
  }
  if (gpu_info.IsMali()) {
    const MaliInfo& mali_info = gpu_info.mali_info;
    if (mali_info.IsMaliT6xx() || mali_info.IsMaliT7xx() ||
        mali_info.IsMaliT8xx()) {
      total_wg_size = 32;
    } else {
      total_wg_size = 64;
    }
  }
  return total_wg_size;
}

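// Returns true if axis `a` is present in `axis`.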
bool HasAxis(const std::vector<Axis>& axis, Axis a) {
  for (const auto& a2 : axis) {
    if (a2 == a) {
      return true;
    }
  }
  return false;
}

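// Returns the shader expression that combines two partial results, `a` and
// `b`, for the given reduction type: addition for SUM/MEAN, multiplication
// for PRODUCT, max/min otherwise.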
std::string MakeOp(OperationType op_type, const std::string& a,
                   const std::string& b) {
  if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
    return "((" + a + ") + (" + b + "))";
  } else if (op_type == OperationType::REDUCE_PRODUCT) {
    return "((" + a + ") * (" + b + "))";
  } else if (op_type == OperationType::REDUCE_MAXIMUM) {
    return "max(" + a + ", " + b + ")";
  } else if (op_type == OperationType::REDUCE_MINIMUM) {
    return "min(" + a + ", " + b + ")";
  }
  return "UnsupportedOperation";
}

// Picks the largest work-group size whose total (x * y * z) does not exceed
// max_total_wg_size, filling dimensions x, y, z from the last entry of
// ordered_sizes backwards. max_total_wg_size must be a power of two (POT).
int3 GetMaximumPossibleWGSize(const std::vector<int>& ordered_sizes,
                              int max_total_wg_size) {
  int3 wg_size = int3(1, 1, 1);
  int wg_size_total = 1;
  for (int i = ordered_sizes.size() - 1; i >= 0; i--) {
    const int wg_index = ordered_sizes.size() - 1 - i;
    if (wg_index >= 3) {
      return wg_size;
    }
    while (ordered_sizes[i] >= wg_size[wg_index] * 2) {
      wg_size_total *= 2;
      if (wg_size_total > max_total_wg_size) {
        return wg_size;
      }
      wg_size[wg_index] *= 2;
    }
  }
  return wg_size;
}

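// Maps each requested axis to its extent in the given shape.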
std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
                                      const BHWC& shape) {
  std::map<Axis, int> result;
  for (auto a : axis) {
    result[a] = shape.get(a);
  }
  return result;
}

std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
                                      const BHWDC& shape) {
  std::map<Axis, int> result;
  for (auto a : axis) {
    result[a] = shape.get(a);
  }
  return result;
}

}  // namespace

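// Collects the reduced axes in a fixed order (CHANNELS, DEPTH, HEIGHT, WIDTH,
// BATCH), with CHANNELS counted in slices of 4, then decides between a plain
// per-thread loop and a shared-memory work-group reduction based on how much
// parallelism the reduced extents offer.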
Reduce::Reduce(const std::map<Axis, int>& axis_to_reduce, OperationType op_type,
               const OperationDef& definition, const GpuInfo& gpu_info)
    : GPUOperation(definition) {
  std::vector<Axis> ordered_axis_to_reduce;
  std::vector<int> ordered_sizes;
  for (const auto& a :
       {Axis::CHANNELS, Axis::DEPTH, Axis::HEIGHT, Axis::WIDTH, Axis::BATCH}) {
    auto it = axis_to_reduce.find(a);
    if (it != axis_to_reduce.end()) {
      ordered_axis_to_reduce.push_back(it->first);
      int reduction_size = it->second;
      if (a == Axis::CHANNELS) {
        reduction_size = DivideRoundUp(reduction_size, 4);
      }
      ordered_sizes.push_back(reduction_size);
    }
  }
  const int max_total_wg_size = GetMaximumWGTotalSize(gpu_info);
  int3 current_wg_size =
      GetMaximumPossibleWGSize(ordered_sizes, max_total_wg_size);
  int current_wg_size_total =
      current_wg_size.x * current_wg_size.y * current_wg_size.z;
  int threshold = max_total_wg_size / 4;
  if (gpu_info.IsApple()) {
    threshold = 16;
  }
  if (current_wg_size_total < threshold) {
    use_wg_reduction_ = false;
  } else {
    use_wg_reduction_ = true;
    work_group_size_ = current_wg_size;
  }
  code_ = GetReduceKernelCode(definition_, work_group_size_,
                              ordered_axis_to_reduce, op_type);
}

Reduce::Reduce(Reduce&& operation)
    : GPUOperation(std::move(operation)),
      use_wg_reduction_(operation.use_wg_reduction_) {}

Reduce& Reduce::operator=(Reduce&& operation) {
  if (this != &operation) {
    use_wg_reduction_ = operation.use_wg_reduction_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

std::string Reduce::GetReduceKernelCode(const OperationDef& op_def,
                                        const int3& work_group_size,
                                        const std::vector<Axis>& axis_to_reduce,
                                        OperationType op_type) {
  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
  args_.AddFloat("inv_multiplier_1");
  args_.AddFloat("inv_multiplier_2");
  args_.AddFloat("mask_x");
  args_.AddFloat("mask_y");
  args_.AddFloat("mask_z");
  args_.AddFloat("mask_w");

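  // Destination axes that are not reduced keep their own coordinate; the
  // reduced axes are looped over inside the kernel.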
  std::set<Axis> axis_to_leave;
  const std::vector<Axis> all_axis = {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH,
                                      Axis::CHANNELS, Axis::BATCH};
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a)) {
      if (!HasAxis(axis_to_reduce, a)) {
        axis_to_leave.insert(a);
      }
    }
  }
  const bool channels_reduction = HasAxis(axis_to_reduce, Axis::CHANNELS);
  int wg_dims = 0;
  if (use_wg_reduction_) {
    if (work_group_size.y == 1 && work_group_size.z == 1) {
      wg_dims = 1;
    } else if (work_group_size.z == 1) {
      wg_dims = 2;
    } else {
      wg_dims = 3;
    }
  }

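  // With a work-group reduction each work group produces one output element,
  // so destination coordinates are derived from the group id rather than the
  // global thread id.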
  auto get_global_id = [&](int i) {
    if (use_wg_reduction_) {
      return "GROUP_ID_" + std::to_string(i);
    } else {
      return "GLOBAL_ID_" + std::to_string(i);
    }
  };

  std::string c;
  const std::string wg_x = std::to_string(work_group_size.x);
  const std::string wg_y = std::to_string(work_group_size.y);
  const std::string wg_z = std::to_string(work_group_size.z);
  const int wg_total_size =
      work_group_size.x * work_group_size.y * work_group_size.z;
  c += "MAIN_FUNCTION($0) {\n";
  if (use_wg_reduction_) {
    c += "  __local float4 accum[" + std::to_string(wg_total_size) + "];\n";
    if (wg_dims == 1) {
      c += "  int local_x = LOCAL_ID_0;\n";
      c += "  int local_id = local_x;\n";
    } else if (wg_dims == 2) {
      c += "  int local_x = LOCAL_ID_0;\n";
      c += "  int local_y = LOCAL_ID_1;\n";
      c += "  int local_id = local_y * " + wg_x + " + local_x;\n";
    } else if (wg_dims == 3) {
      c += "  int local_x = LOCAL_ID_0;\n";
      c += "  int local_y = LOCAL_ID_1;\n";
      c += "  int local_z = LOCAL_ID_2;\n";
      c += "  int local_id = (local_z * " + wg_y + " + local_y) * " + wg_x +
           " + local_x;\n";
    }
  }
  if (axis_to_leave.count(Axis::WIDTH)) {
    if (axis_to_leave.count(Axis::BATCH)) {
      c += "  int linear_id = " + get_global_id(0) + ";\n";
      c += "  int DST_X = linear_id / args.dst_tensor.Batch();\n";
      c += "  int DST_B = linear_id % args.dst_tensor.Batch();\n";
    } else {
      c += "  int DST_X = " + get_global_id(0) + ";\n";
    }
  } else if (axis_to_leave.count(Axis::BATCH)) {
    c += "  int DST_B = " + get_global_id(0) + ";\n";
  }
  if (axis_to_leave.count(Axis::HEIGHT)) {
    if (axis_to_leave.count(Axis::DEPTH)) {
      c += "  int linear_id = " + get_global_id(1) + ";\n";
      c += "  int DST_Y = linear_id % args.dst_tensor.Height();\n";
      c += "  int DST_Z = linear_id / args.dst_tensor.Height();\n";
    } else {
      c += "  int DST_Y = " + get_global_id(1) + ";\n";
    }
  } else if (axis_to_leave.count(Axis::DEPTH)) {
    c += "  int DST_Z = " + get_global_id(1) + ";\n";
  }
  if (axis_to_leave.count(Axis::CHANNELS)) {
    c += "  int DST_S = " + get_global_id(2) + ";\n";
  }
  std::map<Axis, std::string> axis_to_selector = {
      {Axis::BATCH, "Batch()"},   {Axis::WIDTH, "Width()"},
      {Axis::HEIGHT, "Height()"}, {Axis::DEPTH, "Depth()"},
      {Axis::CHANNELS, "Slices()"},
  };
  std::map<Axis, std::string> axis_to_coord = {
      {Axis::BATCH, "B"}, {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"},
      {Axis::DEPTH, "Z"}, {Axis::CHANNELS, "S"},
  };
  std::string dst_check;
  for (auto& axis : axis_to_leave) {
    if (!dst_check.empty()) {
      dst_check += " || ";
    }
    dst_check += "DST_" + axis_to_coord[axis] + " >= args.dst_tensor." +
                 axis_to_selector[axis];
  }
  if (!dst_check.empty()) {
    c += "  if (" + dst_check + ") return;\n";
  }
  std::map<Axis, std::string> src_coords;
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a) && !HasAxis(axis_to_reduce, a)) {
      src_coords[a] = "DST_" + axis_to_coord[a];
    } else {
      src_coords[a] = "0";
    }
  }
  std::string src_coordinates;
  for (const auto& a : all_axis) {
    if (op_def.src_tensors[0].HasAxis(a)) {
      if (!src_coordinates.empty()) {
        src_coordinates += ", ";
      }
      src_coordinates += src_coords[a];
    }
  }
  if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
    c += "  float4 reducer = INIT_FLOAT4(0.0f);\n";
  } else if (op_type == OperationType::REDUCE_PRODUCT) {
    c += "  float4 reducer = INIT_FLOAT4(1.0f);\n";
  } else if (op_type == OperationType::REDUCE_MAXIMUM ||
             op_type == OperationType::REDUCE_MINIMUM) {
    c += "  float4 reducer = args.src_tensor.Read<float>(" + src_coordinates +
         ");\n";
    if (channels_reduction) {
      c += "  reducer.y = reducer.x;\n";
      c += "  reducer.z = reducer.x;\n";
      c += "  reducer.w = reducer.x;\n";
    }
  }
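  // Emit one nested loop per reduced axis. With a work-group reduction the
  // innermost loops start at the local id and advance by the local size, so
  // the threads of a group split the reduced range between them.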
  const std::vector<std::string> local_ids = {"local_x", "local_y", "local_z"};
  const std::vector<std::string> local_sizes = {wg_x, wg_y, wg_z};
  for (int i = 0; i < axis_to_reduce.size(); ++i) {
    const auto& axis = axis_to_reduce[i];
    const int index = axis_to_reduce.size() - 1 - i;
    const std::string first = index < wg_dims ? local_ids[index] : "0";
    const std::string step = index < wg_dims ? local_sizes[index] : "1";
    const std::string src_coord = "SRC_" + axis_to_coord[axis];
    src_coords[axis] = src_coord;
    c += "  for (int " + src_coord + " = " + first + "; " + src_coord +
         " < args.src_tensor." + axis_to_selector[axis] + "; " + src_coord +
         " += " + step + ") {\n";
    if (axis == Axis::CHANNELS) {
      c += "  bool last = SRC_S == args.src_tensor.Slices() - 1;\n";
      c += "  float4 mask_a = last ? INIT_FLOAT4v4(args.mask_x, args.mask_y, "
           "args.mask_z, args.mask_w) : INIT_FLOAT4(1.0f);\n";
      if (op_type == OperationType::REDUCE_PRODUCT ||
          op_type == OperationType::REDUCE_MAXIMUM ||
          op_type == OperationType::REDUCE_MINIMUM) {
        c += "  float4 mask_b = INIT_FLOAT4(1.0f) - mask_a;\n";
      }
    }
  }
  src_coordinates = "";
  for (const auto& a : all_axis) {
    if (op_def.src_tensors[0].HasAxis(a)) {
      if (!src_coordinates.empty()) {
        src_coordinates += ", ";
      }
      src_coordinates += src_coords[a];
    }
  }
  c += "  float4 src_val = args.src_tensor.Read<float>(" + src_coordinates +
       ");\n";
  if (channels_reduction) {
    if (op_type == OperationType::REDUCE_SUM ||
        op_type == OperationType::MEAN) {
      c += "  src_val = src_val * mask_a;\n";
    } else if (op_type == OperationType::REDUCE_PRODUCT) {
      c += "  src_val = src_val * mask_a + mask_b;\n";
    } else if (op_type == OperationType::REDUCE_MAXIMUM ||
               op_type == OperationType::REDUCE_MINIMUM) {
      c += "  src_val = src_val * mask_a + mask_b * src_val.x;\n";
    }
  }
  c += "  reducer = " + MakeOp(op_type, "reducer", "src_val") + ";\n";
  for (int i = 0; i < axis_to_reduce.size(); ++i) {
    c += "  }\n";
  }
  if (op_type == OperationType::MEAN) {
    c += "  reducer *= args.inv_multiplier_1;\n";
  }
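  // Work-group reduction: each thread stores its partial result in local
  // memory, the partials are combined by a factor-of-4 tree, and whatever is
  // left after the tree is folded in serially.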
  if (use_wg_reduction_) {
    c += "  accum[local_id] = reducer;\n";
    c += "  LOCAL_MEM_BARRIER;\n";
    const int total_size =
        work_group_size.x * work_group_size.y * work_group_size.z;
    int offset = 1;
    int remainder = total_size / 4;
    for (; remainder >= 8; remainder /= 4, offset *= 4) {
      c += "  if (local_id < " + std::to_string(remainder) + ") {\n";
      c += "    int t = local_id * " + std::to_string(offset * 4) + ";\n";
      c += "    float4 sum = accum[t + " + std::to_string(offset) + "];\n";
      c += "    sum = " +
           MakeOp(op_type, "sum",
                  "accum[t + " + std::to_string(offset * 2) + "]") +
           ";\n";
      c += "    sum = " +
           MakeOp(op_type, "sum",
                  "accum[t + " + std::to_string(offset * 3) + "]") +
           ";\n";
      c += "    accum[t] = " + MakeOp(op_type, "accum[t]", "sum") + ";\n";
      c += "  }\n";
      c += "  LOCAL_MEM_BARRIER;\n";
    }
    c += "  reducer = accum[0];\n";
    remainder *= 4;
    for (int i = 1; i < remainder; ++i) {
      c += "  reducer = " +
           MakeOp(op_type, "reducer",
                  "accum[" + std::to_string(offset * i) + "]") +
           ";\n";
    }
    if (op_type == OperationType::MEAN) {
      c += "  reducer *= args.inv_multiplier_2;\n";
    }
  }
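  // When channels are reduced, fold the four lanes of the accumulator into
  // its .x component.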
  if (channels_reduction) {
    if (op_type == OperationType::REDUCE_SUM ||
        op_type == OperationType::MEAN) {
      c += "  reducer.x += reducer.y + reducer.z + reducer.w;\n";
    } else if (op_type == OperationType::REDUCE_PRODUCT) {
      c += "  reducer.x *= reducer.y * reducer.z * reducer.w;\n";
    } else if (op_type == OperationType::REDUCE_MAXIMUM) {
      c += "  reducer.x = max(reducer.x, reducer.y);\n";
      c += "  reducer.x = max(reducer.x, reducer.z);\n";
      c += "  reducer.x = max(reducer.x, reducer.w);\n";
    } else if (op_type == OperationType::REDUCE_MINIMUM) {
      c += "  reducer.x = min(reducer.x, reducer.y);\n";
      c += "  reducer.x = min(reducer.x, reducer.z);\n";
      c += "  reducer.x = min(reducer.x, reducer.w);\n";
    }
  }
  c += "  FLT4 result = TO_FLT4(reducer);\n";
  std::string dst_coordinates;
  for (const auto& a : all_axis) {
    if (op_def.dst_tensors[0].HasAxis(a)) {
      if (!dst_coordinates.empty()) {
        dst_coordinates += ", ";
      }
      if (axis_to_leave.count(a)) {
        dst_coordinates += "DST_" + axis_to_coord[a];
      } else {
        dst_coordinates += "0";
      }
    }
  }
  c += "  args.dst_tensor.Write(result, " + dst_coordinates + ");\n";
  c += "}\n";
  return c;
}

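// Binds the MEAN normalization factors and the mask for the last (possibly
// partial) channel slice. With a work-group reduction the 1/N factor is split
// into two multipliers: one applied per thread before the local-memory
// combine and one after it.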
absl::Status Reduce::BindArguments(ArgumentsBinder* args) {
  const double total_src_elements = 1.0 * src_[0]->Batch() * src_[0]->Width() *
                                    src_[0]->Height() * src_[0]->Depth() *
                                    src_[0]->Channels();
  const double total_dst_elements = 1.0 * dst_[0]->Batch() * dst_[0]->Width() *
                                    dst_[0]->Height() * dst_[0]->Depth() *
                                    dst_[0]->Channels();
  const double reduction_size = total_src_elements / total_dst_elements;
  if (use_wg_reduction_) {
    const double size_0 =
        work_group_size_.x * work_group_size_.y * work_group_size_.z;
    const double size_1 = reduction_size / size_0;
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1));
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0));
  } else {
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / reduction_size));
    RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0));
  }
  float4 mask = GetMaskForLastPlane(src_[0]->Channels());
  RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
  RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
  RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
  RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w));
  return absl::OkStatus();
}

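// One thread per output element; with a work-group reduction the grid is
// scaled by the work-group size so each group maps to one output element.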
int3 Reduce::GetGridSize() const {
  int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  int grid_y = dst_[0]->Height() * dst_[0]->Depth();
  int grid_z = dst_[0]->Slices();
  if (use_wg_reduction_) {
    grid_x *= work_group_size_.x;
    grid_y *= work_group_size_.y;
    grid_z *= work_group_size_.z;
  }
  return int3(grid_x, grid_y, grid_z);
}

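// The work-group size is fixed when the kernel does a work-group reduction;
// otherwise it is left to the usual tuning heuristics.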
void Reduce::GetPossibleKernelWorkGroups(TuningType tuning_type,
                                         const GpuInfo& gpu_info,
                                         const KernelInfo& kernel_info,
                                         std::vector<int3>* work_groups) const {
  if (use_wg_reduction_) {
    work_groups->push_back(work_group_size_);
  } else {
    GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                          work_groups);
  }
}

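// Factory helpers: translate the set of reduced axes plus the source shape
// into the per-axis reduction sizes the operation needs.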
Reduce CreateReduce(const std::set<Axis>& axis_to_reduce, const BHWC& src_shape,
                    OperationType op_type, const OperationDef& definition,
                    const GpuInfo& gpu_info) {
  return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
                definition, gpu_info);
}

Reduce CreateReduce(const std::set<Axis>& axis_to_reduce,
                    const BHWDC& src_shape, OperationType op_type,
                    const OperationDef& definition, const GpuInfo& gpu_info) {
  return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
                definition, gpu_info);
}

}  // namespace gpu
}  // namespace tflite