• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
17 
18 #include <string>
19 
20 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
21 
22 namespace tflite {
23 namespace gpu {
24 
Split(const OperationDef & definition,const SplitAttributes & attr)25 Split::Split(const OperationDef& definition, const SplitAttributes& attr)
26     : GPUOperation(definition), attr_(attr) {
27   work_group_size_ = int3(8, 4, 1);
28   code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
29 }
30 
GetSplitCode()31 std::string Split::GetSplitCode() {
32   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
33   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
34     AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
35   }
36   const std::string task_width =
37       attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
38   const std::string task_height =
39       attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
40   const std::string task_depth =
41       attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
42   const std::string task_batch =
43       attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
44   const std::string task_slices =
45       attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
46 
47   std::map<Axis, std::string> axis_to_selector = {
48       {Axis::WIDTH, "Width"}, {Axis::HEIGHT, "Height"},
49       {Axis::DEPTH, "Depth"}, {Axis::CHANNELS, "Slices"},
50       {Axis::BATCH, "Batch"},
51   };
52   std::map<Axis, std::string> axis_to_coord = {
53       {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
54       {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
55   };
56 
57   std::string c;
58   c += "MAIN_FUNCTION($0) {\n";
59   if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
60     c += "  int linear_id = GLOBAL_ID_0;\n";
61     c += "  int X = linear_id / " + task_batch + ";\n";
62     c += "  int B = linear_id % " + task_batch + ";\n";
63     c += "  if (X >= " + task_width + ") return;\n";
64   } else {
65     c += "  int X = GLOBAL_ID_0;\n";
66     c += "  if (X >= " + task_width + ") return;\n";
67   }
68   if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
69     c += "  int linear_id = GLOBAL_ID_1;\n";
70     c += "  int Y = linear_id % " + task_height + ";\n";
71     c += "  int D = linear_id / " + task_height + ";\n";
72     c += "  if (D >= " + task_depth + ") return;\n";
73   } else {
74     c += "  int Y = GLOBAL_ID_1;\n";
75     c += "  if (Y >= " + task_height + ") return;\n";
76   }
77   c += "  int S = GLOBAL_ID_2;\n";
78   c += "  if (S >= " + task_slices + ") return;\n";
79   c += "  int src_counter = 0;\n";
80   std::vector<std::string> src_coords;
81   for (auto axis :
82        {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS, Axis::BATCH}) {
83     if (definition_.src_tensors[0].HasAxis(axis)) {
84       const std::string coord_name =
85           attr_.axis == axis ? "src_counter" : axis_to_coord[axis];
86       src_coords.push_back(coord_name);
87     }
88   }
89   std::string src_coords_str = src_coords[0];
90   for (int i = 1; i < src_coords.size(); ++i) {
91     src_coords_str += ", " + src_coords[i];
92   }
93   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
94     std::vector<std::string> dst_coords;
95     for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS,
96                       Axis::BATCH}) {
97       if (definition_.dst_tensors[i].HasAxis(axis)) {
98         const std::string coord_name =
99             attr_.axis == axis ? "i" : axis_to_coord[axis];
100         dst_coords.push_back(coord_name);
101       }
102     }
103     std::string dst_coords_str = dst_coords[0];
104     for (int j = 1; j < dst_coords.size(); ++j) {
105       dst_coords_str += ", " + dst_coords[j];
106     }
107     const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
108     c += "  for (int i = 0; i < " + dst_name + "." +
109          axis_to_selector[attr_.axis] + "(); ++i, src_counter++) {\n";
110     c += "    FLT4 result = args.src_tensor.Read(" + src_coords_str + ");\n";
111     c += "    " + dst_name + ".Write(result, " + dst_coords_str + ");\n";
112     c += "  }\n";
113   }
114   c += "}\n";
115   return c;
116 }
117 
GetSplitChannelsCode()118 std::string Split::GetSplitChannelsCode() {
119   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
120   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
121     AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
122   }
123 
124   const std::string batch_coord =
125       definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
126   std::string coords = "X, Y";
127   std::string c;
128   c += "MAIN_FUNCTION($0) {\n";
129   if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
130     c += "  int linear_id = GLOBAL_ID_0;\n";
131     c += "  int X = linear_id / args.src_tensor.Batch();\n";
132     c += "  int B = linear_id % args.src_tensor.Batch();\n";
133     c += "  if (X >= args.src_tensor.Width()) return;\n";
134   } else {
135     c += "  int X = GLOBAL_ID_0;\n";
136     c += "  if (X >= args.src_tensor.Width()) return;\n";
137   }
138   if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
139     c += "  int linear_id = GLOBAL_ID_1;\n";
140     c += "  int Y = linear_id % args.src_tensor.Height();\n";
141     c += "  int Z = linear_id / args.src_tensor.Height();\n";
142     c += "  if (Z >= args.src_tensor.Depth()) return;\n";
143     coords += ", Z";
144   } else {
145     c += "  int Y = GLOBAL_ID_1;\n";
146     c += "  if (Y >= args.src_tensor.Height()) return;\n";
147   }
148   c += "  int src_channel = 0;\n";
149   const std::string postfixes[] = {"x", "y", "z", "w"};
150   for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
151     const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
152     c += "  for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
153     c += "    FLT4 result = INIT_FLT4(0.0f);\n";
154     for (int j = 0; j < 4; ++j) {
155       c += "    if (i * 4 + " + std::to_string(j) + " < " + dst_name +
156            ".Channels()) {\n";
157       c += "      int src_slice = src_channel >> 2;\n";
158       c += "      int src_sub_ch = src_channel & 3;\n";
159       c += "      FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
160            batch_coord + ");\n";
161       c += "      result." + postfixes[j] +
162            " = SELECT_BY_INDEX_FROM_FLT4(t, src_sub_ch);\n";
163       c += "      src_channel++;\n";
164       c += "    }\n";
165     }
166     c += "    " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
167          ");\n";
168     c += "  }\n";
169   }
170   c += "}\n";
171   return c;
172 }
173 
GetGridSize() const174 int3 Split::GetGridSize() const {
175   const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
176   const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
177   const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
178   const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
179   const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
180   const int grid_x = width * batch;
181   const int grid_y = height * depth;
182   const int grid_z = slices;
183   return int3(grid_x, grid_y, grid_z);
184 }
185 
CreateSplit(const OperationDef & definition,const SplitAttributes & attr)186 Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
187   return Split(definition, attr);
188 }
189 
190 }  // namespace gpu
191 }  // namespace tflite
192