• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/reduce.h"
17 
18 #include <set>
19 #include <string>
20 
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
29 namespace {
GetMaximumWGTotalSize(const GpuInfo & gpu_info)30 int GetMaximumWGTotalSize(const GpuInfo& gpu_info) {
31   // total_wg_size must be power of 2 and >= 4;
32   int total_wg_size = 256;
33   if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
34     total_wg_size = 128;
35   }
36   if (gpu_info.IsMali()) {
37     const MaliInfo& mali_info = gpu_info.mali_info;
38     if (mali_info.IsMaliT6xx() || mali_info.IsMaliT7xx() ||
39         mali_info.IsMaliT8xx()) {
40       total_wg_size = 32;
41     } else {
42       total_wg_size = 64;
43     }
44   }
45   return total_wg_size;
46 }
47 
HasAxis(const std::vector<Axis> & axis,Axis a)48 bool HasAxis(const std::vector<Axis>& axis, Axis a) {
49   for (const auto& a2 : axis) {
50     if (a2 == a) {
51       return true;
52     }
53   }
54   return false;
55 }
56 
MakeOp(OperationType op_type,const std::string & a,const std::string & b)57 std::string MakeOp(OperationType op_type, const std::string& a,
58                    const std::string& b) {
59   if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
60     return "((" + a + ") + (" + b + "))";
61   } else if (op_type == OperationType::REDUCE_PRODUCT) {
62     return "((" + a + ") * (" + b + "))";
63   } else if (op_type == OperationType::REDUCE_MAXIMUM) {
64     return "max(" + a + ", " + b + ")";
65   } else if (op_type == OperationType::REDUCE_MINIMUM) {
66     return "min(" + a + ", " + b + ")";
67   }
68   return "UnsupportedOperation";
69 }
70 
71 // max_total_wg_size is pot
GetMaximumPossibleWGSize(const std::vector<int> & ordered_sizes,int max_total_wg_size)72 int3 GetMaximumPossibleWGSize(const std::vector<int>& ordered_sizes,
73                               int max_total_wg_size) {
74   int3 wg_size = int3(1, 1, 1);
75   int wg_size_total = 1;
76   for (int i = ordered_sizes.size() - 1; i >= 0; i--) {
77     const int wg_index = ordered_sizes.size() - 1 - i;
78     if (wg_index >= 3) {
79       return wg_size;
80     }
81     while (ordered_sizes[i] >= wg_size[wg_index] * 2) {
82       wg_size_total *= 2;
83       if (wg_size_total > max_total_wg_size) {
84         return wg_size;
85       }
86       wg_size[wg_index] *= 2;
87     }
88   }
89   return wg_size;
90 }
91 
GetSizesFromShape(const std::set<Axis> & axis,const BHWC & shape)92 std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
93                                       const BHWC& shape) {
94   std::map<Axis, int> result;
95   for (auto a : axis) {
96     result[a] = shape.get(a);
97   }
98   return result;
99 }
100 
GetSizesFromShape(const std::set<Axis> & axis,const BHWDC & shape)101 std::map<Axis, int> GetSizesFromShape(const std::set<Axis>& axis,
102                                       const BHWDC& shape) {
103   std::map<Axis, int> result;
104   for (auto a : axis) {
105     result[a] = shape.get(a);
106   }
107   return result;
108 }
109 
110 }  // namespace
111 
Reduce(const std::map<Axis,int> & axis_to_reduce,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)112 Reduce::Reduce(const std::map<Axis, int>& axis_to_reduce, OperationType op_type,
113                const OperationDef& definition, const GpuInfo& gpu_info)
114     : GPUOperation(definition) {
115   std::vector<Axis> ordered_axis_to_reduce;
116   std::vector<int> ordered_sizes;
117   for (const auto& a :
118        {Axis::CHANNELS, Axis::DEPTH, Axis::HEIGHT, Axis::WIDTH, Axis::BATCH}) {
119     auto it = axis_to_reduce.find(a);
120     if (it != axis_to_reduce.end()) {
121       ordered_axis_to_reduce.push_back(it->first);
122       int reduction_size = it->second;
123       if (a == Axis::CHANNELS) {
124         reduction_size = DivideRoundUp(reduction_size, 4);
125       }
126       ordered_sizes.push_back(reduction_size);
127     }
128   }
129   const int max_total_wg_size = GetMaximumWGTotalSize(gpu_info);
130   int3 current_wg_size =
131       GetMaximumPossibleWGSize(ordered_sizes, max_total_wg_size);
132   int current_wg_size_total =
133       current_wg_size.x * current_wg_size.y * current_wg_size.z;
134   int threshold = max_total_wg_size / 4;
135   if (gpu_info.IsApple()) {
136     threshold = 16;
137   }
138   if (current_wg_size_total < threshold) {
139     use_wg_reduction_ = false;
140   } else {
141     use_wg_reduction_ = true;
142     work_group_size_ = current_wg_size;
143   }
144   code_ = GetReduceKernelCode(definition_, work_group_size_,
145                               ordered_axis_to_reduce, op_type);
146 }
147 
Reduce(Reduce && operation)148 Reduce::Reduce(Reduce&& operation)
149     : GPUOperation(std::move(operation)),
150       use_wg_reduction_(operation.use_wg_reduction_) {}
151 
operator =(Reduce && operation)152 Reduce& Reduce::operator=(Reduce&& operation) {
153   if (this != &operation) {
154     use_wg_reduction_ = operation.use_wg_reduction_;
155     GPUOperation::operator=(std::move(operation));
156   }
157   return *this;
158 }
159 
GetReduceKernelCode(const OperationDef & op_def,const int3 & work_group_size,const std::vector<Axis> & axis_to_reduce,OperationType op_type)160 std::string Reduce::GetReduceKernelCode(const OperationDef& op_def,
161                                         const int3& work_group_size,
162                                         const std::vector<Axis>& axis_to_reduce,
163                                         OperationType op_type) {
164   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
165   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
166   args_.AddFloat("inv_multiplier_1");
167   args_.AddFloat("inv_multiplier_2");
168   args_.AddFloat("mask_x");
169   args_.AddFloat("mask_y");
170   args_.AddFloat("mask_z");
171   args_.AddFloat("mask_w");
172 
173   std::set<Axis> axis_to_leave;
174   const std::vector<Axis> all_axis = {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH,
175                                       Axis::CHANNELS, Axis::BATCH};
176   for (const auto& a : all_axis) {
177     if (op_def.dst_tensors[0].HasAxis(a)) {
178       if (!HasAxis(axis_to_reduce, a)) {
179         axis_to_leave.insert(a);
180       }
181     }
182   }
183   const bool channels_reductin = HasAxis(axis_to_reduce, Axis::CHANNELS);
184   int wg_dims = 0;
185   if (use_wg_reduction_) {
186     if (work_group_size.y == 1 && work_group_size.z == 1) {
187       wg_dims = 1;
188     } else if (work_group_size.z == 1) {
189       wg_dims = 2;
190     } else {
191       wg_dims = 3;
192     }
193   }
194 
195   auto get_global_id = [&](int i) {
196     if (use_wg_reduction_) {
197       return "GROUP_ID_" + std::to_string(i);
198     } else {
199       return "GLOBAL_ID_" + std::to_string(i);
200     }
201   };
202 
203   std::string c;
204   const std::string wg_x = std::to_string(work_group_size.x);
205   const std::string wg_y = std::to_string(work_group_size.y);
206   const std::string wg_z = std::to_string(work_group_size.z);
207   const int wg_total_size =
208       work_group_size.x * work_group_size.y * work_group_size.z;
209   c += "MAIN_FUNCTION($0) {\n";
210   if (use_wg_reduction_) {
211     c += "  __local float4 accum[" + std::to_string(wg_total_size) + "];\n";
212     if (wg_dims == 1) {
213       c += "  int local_x = LOCAL_ID_0;\n";
214       c += "  int local_id = local_x;\n";
215     } else if (wg_dims == 2) {
216       c += "  int local_x = LOCAL_ID_0;\n";
217       c += "  int local_y = LOCAL_ID_1;\n";
218       c += "  int local_id = local_y * " + wg_x + " + local_x;\n";
219     } else if (wg_dims == 3) {
220       c += "  int local_x = LOCAL_ID_0;\n";
221       c += "  int local_y = LOCAL_ID_1;\n";
222       c += "  int local_z = LOCAL_ID_2;\n";
223       c += "  int local_id = (local_z * " + wg_y + " + local_y) * " + wg_x +
224            " + local_x;\n";
225     }
226   }
227   if (axis_to_leave.count(Axis::WIDTH)) {
228     if (axis_to_leave.count(Axis::BATCH)) {
229       c += "  int linear_id = " + get_global_id(0) + ";\n";
230       c += "  int DST_X = linear_id / args.dst_tensor.Batch();\n";
231       c += "  int DST_B = linear_id % args.dst_tensor.Batch();\n";
232     } else {
233       c += "  int DST_X = " + get_global_id(0) + ";\n";
234     }
235   } else if (axis_to_leave.count(Axis::BATCH)) {
236     c += "  int DST_B = " + get_global_id(0) + ";\n";
237   }
238   if (axis_to_leave.count(Axis::HEIGHT)) {
239     if (axis_to_leave.count(Axis::DEPTH)) {
240       c += "  int linear_id = " + get_global_id(1) + ";\n";
241       c += "  int DST_Y = linear_id % args.dst_tensor.Height();\n";
242       c += "  int DST_Z = linear_id / args.dst_tensor.Height();\n";
243     } else {
244       c += "  int DST_Y = " + get_global_id(1) + ";\n";
245     }
246   } else if (axis_to_leave.count(Axis::DEPTH)) {
247     c += "  int DST_Z = " + get_global_id(1) + ";\n";
248   }
249   if (axis_to_leave.count(Axis::CHANNELS)) {
250     c += "  int DST_S = " + get_global_id(2) + ";\n";
251   }
252   std::map<Axis, std::string> axis_to_selector = {
253       {Axis::BATCH, "Batch()"},     {Axis::WIDTH, "Width()"},
254       {Axis::HEIGHT, "Height()"},   {Axis::DEPTH, "Depth()"},
255       {Axis::CHANNELS, "Slices()"},
256   };
257   std::map<Axis, std::string> axis_to_coord = {
258       {Axis::BATCH, "B"}, {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"},
259       {Axis::DEPTH, "Z"}, {Axis::CHANNELS, "S"},
260   };
261   std::string dst_check;
262   for (auto& axis : axis_to_leave) {
263     if (!dst_check.empty()) {
264       dst_check += " || ";
265     }
266     dst_check += "DST_" + axis_to_coord[axis] + " >= args.dst_tensor." +
267                  axis_to_selector[axis];
268   }
269   if (!dst_check.empty()) {
270     c += "  if (" + dst_check + ") return;\n";
271   }
272   std::map<Axis, std::string> src_coords;
273   for (const auto& a : all_axis) {
274     if (op_def.dst_tensors[0].HasAxis(a) && !HasAxis(axis_to_reduce, a)) {
275       src_coords[a] = "DST_" + axis_to_coord[a];
276     } else {
277       src_coords[a] = "0";
278     }
279   }
280   std::string src_coordinates;
281   for (const auto& a : all_axis) {
282     if (op_def.src_tensors[0].HasAxis(a)) {
283       if (!src_coordinates.empty()) {
284         src_coordinates += ", ";
285       }
286       src_coordinates += src_coords[a];
287     }
288   }
289   if (op_type == OperationType::REDUCE_SUM || op_type == OperationType::MEAN) {
290     c += "  float4 reducer = INIT_FLOAT4(0.0f);\n";
291   } else if (op_type == OperationType::REDUCE_PRODUCT) {
292     c += "  float4 reducer = INIT_FLOAT4(1.0f);\n";
293   } else if (op_type == OperationType::REDUCE_MAXIMUM ||
294              op_type == OperationType::REDUCE_MINIMUM) {
295     c += "  float4 reducer = args.src_tensor.Read<float>(" + src_coordinates +
296          ");\n";
297     if (channels_reductin) {
298       c += "  reducer.y = reducer.x;\n";
299       c += "  reducer.z = reducer.x;\n";
300       c += "  reducer.w = reducer.x;\n";
301     }
302   }
303   const std::vector<std::string> local_ids = {"local_x", "local_y", "local_z"};
304   const std::vector<std::string> local_sizes = {wg_x, wg_y, wg_z};
305   for (int i = 0; i < axis_to_reduce.size(); ++i) {
306     const auto& axis = axis_to_reduce[i];
307     const int index = axis_to_reduce.size() - 1 - i;
308     const std::string first = index < wg_dims ? local_ids[index] : "0";
309     const std::string step = index < wg_dims ? local_sizes[index] : "1";
310     const std::string src_coord = "SRC_" + axis_to_coord[axis];
311     src_coords[axis] = src_coord;
312     c += "  for (int " + src_coord + " = " + first + "; " + src_coord +
313          " < args.src_tensor." + axis_to_selector[axis] + "; " + src_coord +
314          " += " + step + ") {\n";
315     if (axis == Axis::CHANNELS) {
316       c += "    bool last = SRC_S == args.src_tensor.Slices() - 1;\n";
317       c += "    float4 mask_a = last ? INIT_FLOAT4v4(args.mask_x, args.mask_y, "
318            "args.mask_z, args.mask_w) : INIT_FLOAT4(1.0f);\n";
319       if (op_type == OperationType::REDUCE_PRODUCT ||
320           op_type == OperationType::REDUCE_MAXIMUM ||
321           op_type == OperationType::REDUCE_MINIMUM) {
322         c += "    float4 mask_b = INIT_FLOAT4(1.0f) - mask_a;\n";
323       }
324     }
325   }
326   src_coordinates = "";
327   for (const auto& a : all_axis) {
328     if (op_def.src_tensors[0].HasAxis(a)) {
329       if (!src_coordinates.empty()) {
330         src_coordinates += ", ";
331       }
332       src_coordinates += src_coords[a];
333     }
334   }
335   c += "    float4 src_val = args.src_tensor.Read<float>(" + src_coordinates +
336        ");\n";
337   if (channels_reductin) {
338     if (op_type == OperationType::REDUCE_SUM ||
339         op_type == OperationType::MEAN) {
340       c += "    src_val = src_val * mask_a;\n";
341     } else if (op_type == OperationType::REDUCE_PRODUCT) {
342       c += "    src_val = src_val * mask_a + mask_b;\n";
343     } else if (op_type == OperationType::REDUCE_MAXIMUM ||
344                op_type == OperationType::REDUCE_MINIMUM) {
345       c += "    src_val = src_val * mask_a + mask_b * src_val.x;\n";
346     }
347   }
348   c += "    reducer = " + MakeOp(op_type, "reducer", "src_val") + ";\n";
349   for (int i = 0; i < axis_to_reduce.size(); ++i) {
350     c += "  }\n";
351   }
352   if (op_type == OperationType::MEAN) {
353     c += "  reducer *= args.inv_multiplier_1;\n";
354   }
355   if (use_wg_reduction_) {
356     c += "  accum[local_id] = reducer;\n";
357     c += "  LOCAL_MEM_BARRIER;\n";
358     const int total_size =
359         work_group_size.x * work_group_size.y * work_group_size.z;
360     int offset = 1;
361     int reminder = total_size / 4;
362     for (; reminder >= 8; reminder /= 4, offset *= 4) {
363       c += "  if (local_id < " + std::to_string(reminder) + ") {\n";
364       c += "    int t = local_id * " + std::to_string(offset * 4) + ";\n";
365       c += "    float4 sum = accum[t + " + std::to_string(offset) + "];\n";
366       c += "    sum = " +
367            MakeOp(op_type, "sum",
368                   "accum[t + " + std::to_string(offset * 2) + "]") +
369            ";\n";
370       c += "    sum = " +
371            MakeOp(op_type, "sum",
372                   "accum[t + " + std::to_string(offset * 3) + "]") +
373            ";\n";
374       c += "    accum[t] = " + MakeOp(op_type, "accum[t]", "sum") + ";\n";
375       c += "  }\n";
376       c += "  LOCAL_MEM_BARRIER;\n";
377     }
378     c += "  reducer = accum[0];\n";
379     reminder *= 4;
380     for (int i = 1; i < reminder; ++i) {
381       c += "  reducer = " +
382            MakeOp(op_type, "reducer",
383                   "accum[" + std::to_string(offset * i) + "]") +
384            ";\n";
385     }
386     if (op_type == OperationType::MEAN) {
387       c += "  reducer *= args.inv_multiplier_2;\n";
388     }
389   }
390   if (channels_reductin) {
391     if (op_type == OperationType::REDUCE_SUM ||
392         op_type == OperationType::MEAN) {
393       c += "  reducer.x += reducer.y + reducer.z + reducer.w;\n";
394     } else if (op_type == OperationType::REDUCE_PRODUCT) {
395       c += "  reducer.x *= reducer.y * reducer.z * reducer.w;\n";
396     } else if (op_type == OperationType::REDUCE_MAXIMUM) {
397       c += "  reducer.x = max(reducer.x, reducer.y);\n";
398       c += "  reducer.x = max(reducer.x, reducer.z);\n";
399       c += "  reducer.x = max(reducer.x, reducer.w);\n";
400     } else if (op_type == OperationType::REDUCE_MINIMUM) {
401       c += "  reducer.x = min(reducer.x, reducer.y);\n";
402       c += "  reducer.x = min(reducer.x, reducer.z);\n";
403       c += "  reducer.x = min(reducer.x, reducer.w);\n";
404     }
405   }
406   c += "  FLT4 result = TO_FLT4(reducer);\n";
407   std::string dst_coordinates;
408   for (const auto& a : all_axis) {
409     if (op_def.dst_tensors[0].HasAxis(a)) {
410       if (!dst_coordinates.empty()) {
411         dst_coordinates += ", ";
412       }
413       if (axis_to_leave.count(a)) {
414         dst_coordinates += "DST_" + axis_to_coord[a];
415       } else {
416         dst_coordinates += "0";
417       }
418     }
419   }
420   c += "  args.dst_tensor.Write(result, " + dst_coordinates + ");\n";
421   c += "}\n";
422   return c;
423 }
424 
BindArguments(ArgumentsBinder * args)425 absl::Status Reduce::BindArguments(ArgumentsBinder* args) {
426   const double total_src_elements = 1.0 * src_[0]->Batch() * src_[0]->Width() *
427                                     src_[0]->Height() * src_[0]->Depth() *
428                                     src_[0]->Channels();
429   const double total_dst_elements = 1.0 * dst_[0]->Batch() * dst_[0]->Width() *
430                                     dst_[0]->Height() * dst_[0]->Depth() *
431                                     dst_[0]->Channels();
432   const double reduction_size = total_src_elements / total_dst_elements;
433   if (use_wg_reduction_) {
434     const double size_0 =
435         work_group_size_.x * work_group_size_.y * work_group_size_.z;
436     const double size_1 = reduction_size / size_0;
437     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1));
438     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0));
439   } else {
440     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / reduction_size));
441     RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0));
442   }
443   float4 mask = GetMaskForLastPlane(src_[0]->Channels());
444   RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
445   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
446   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
447   RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w));
448   return absl::OkStatus();
449 }
450 
GetGridSize() const451 int3 Reduce::GetGridSize() const {
452   int grid_x = dst_[0]->Width() * dst_[0]->Batch();
453   int grid_y = dst_[0]->Height() * dst_[0]->Depth();
454   int grid_z = dst_[0]->Slices();
455   if (use_wg_reduction_) {
456     grid_x *= work_group_size_.x;
457     grid_y *= work_group_size_.y;
458     grid_z *= work_group_size_.z;
459   }
460   return int3(grid_x, grid_y, grid_z);
461 }
462 
GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups) const463 void Reduce::GetPossibleKernelWorkGroups(TuningType tuning_type,
464                                          const GpuInfo& gpu_info,
465                                          const KernelInfo& kernel_info,
466                                          std::vector<int3>* work_groups) const {
467   if (use_wg_reduction_) {
468     work_groups->push_back(work_group_size_);
469   } else {
470     GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
471                           work_groups);
472   }
473 }
474 
CreateReduce(const std::set<Axis> & axis_to_reduce,const BHWC & src_shape,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)475 Reduce CreateReduce(const std::set<Axis>& axis_to_reduce, const BHWC& src_shape,
476                     OperationType op_type, const OperationDef& definition,
477                     const GpuInfo& gpu_info) {
478   return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
479                 definition, gpu_info);
480 }
481 
CreateReduce(const std::set<Axis> & axis_to_reduce,const BHWDC & src_shape,OperationType op_type,const OperationDef & definition,const GpuInfo & gpu_info)482 Reduce CreateReduce(const std::set<Axis>& axis_to_reduce,
483                     const BHWDC& src_shape, OperationType op_type,
484                     const OperationDef& definition, const GpuInfo& gpu_info) {
485   return Reduce(GetSizesFromShape(axis_to_reduce, src_shape), op_type,
486                 definition, gpu_info);
487 }
488 
489 }  // namespace gpu
490 }  // namespace tflite
491