1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv_3x3_stride_h2.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23
24 namespace tflite {
25 namespace gpu {
26 namespace {
GetKernelDepthWiseConv3x3StrideH2(const OperationDef & definition,bool weights_are_buffer,bool local_mem_uploads)27 std::string GetKernelDepthWiseConv3x3StrideH2(const OperationDef& definition,
28 bool weights_are_buffer,
29 bool local_mem_uploads) {
30 const auto src_tensor_type = definition.src_tensors[0].storage_type;
31 const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
32 src_tensor_type == TensorStorageType::IMAGE_BUFFER;
33
34 std::string c = "MAIN_FUNCTION($0) {\n";
35 if (definition.dst_tensors[0].HasAxis(Axis::BATCH)) {
36 c += " int linear_id = GLOBAL_ID_0;\n";
37 c += " int X = linear_id / args.dst_tensor.Batch();\n";
38 c += " int B = linear_id % args.dst_tensor.Batch();\n";
39 c += " args.dst_tensor.SetBatchRef(B);\n";
40 c += " args.src_tensor.SetBatchRef(B);\n";
41 } else {
42 c += " int X = GLOBAL_ID_0;\n";
43 }
44 c += R"(
45 int Y = GLOBAL_ID_1 * 2;
46 int S = GLOBAL_ID_2;
47
48 ACCUM_FLT4 r0 = INIT_ACCUM_FLT4(0.0f);
49 ACCUM_FLT4 l0 = INIT_ACCUM_FLT4(0.0f);
50 )";
51 if (local_mem_uploads) {
52 c += " __local FLT4 f[10];\n";
53 c += " int local_id = LOCAL_ID_1 * 8 + LOCAL_ID_0;\n";
54 c += " if (local_id < 10) {\n";
55 c += " f[local_id] = args.weights.Read(S * 10 + local_id);\n";
56 c += " }\n";
57 c += " LOCAL_MEM_BARRIER;\n";
58 } else if (weights_are_buffer) {
59 c += " __global FLT4* f = args.weights.GetPtr() + S * 10;\n";
60 }
61 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
62 "|| S >= args.dst_tensor.Slices()) { \n";
63 c += " return; \n";
64 c += " } \n";
65 c += " FLT4 s0, s1, s2;\n";
66 c += " int x0 = X * args.stride_x + args.padding_x;\n";
67 c += " int x1 = X * args.stride_x + args.padding_x + args.dilation_x;\n";
68 c += " int x2 = X * args.stride_x + args.padding_x + 2 * args.dilation_x;\n";
69 c += " int y0 = Y * 2 + args.padding_y;\n";
70 c += " int y1 = Y * 2 + args.padding_y + 1;\n";
71 c += " int y2 = Y * 2 + args.padding_y + 2;\n";
72 c += " int y3 = Y * 2 + args.padding_y + 3;\n";
73 c += " int y4 = Y * 2 + args.padding_y + 4;\n";
74 std::string W[9] = {"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"};
75 std::string bias = "bias";
76 if (!weights_are_buffer) {
77 c += " FLT4 f0 = args.weights.Read(0, S);\n";
78 c += " FLT4 f1 = args.weights.Read(1, S);\n";
79 c += " FLT4 f2 = args.weights.Read(2, S);\n";
80 c += " FLT4 f3 = args.weights.Read(3, S);\n";
81 c += " FLT4 f4 = args.weights.Read(4, S);\n";
82 c += " FLT4 f5 = args.weights.Read(5, S);\n";
83 c += " FLT4 f6 = args.weights.Read(6, S);\n";
84 c += " FLT4 f7 = args.weights.Read(7, S);\n";
85 c += " FLT4 f8 = args.weights.Read(8, S);\n";
86 }
87 if (manual_clamp) {
88 c += " bool x0_in = x0 >= 0 && x0 < args.src_tensor.Width();\n";
89 c += " bool x1_in = x1 >= 0 && x1 < args.src_tensor.Width();\n";
90 c += " bool x2_in = x2 >= 0 && x2 < args.src_tensor.Width();\n";
91 c += " bool y0_in = y0 >= 0 && y0 < args.src_tensor.Height();\n";
92 c += " bool y1_in = y1 >= 0 && y1 < args.src_tensor.Height();\n";
93 c += " bool y2_in = y2 >= 0 && y2 < args.src_tensor.Height();\n";
94 c += " bool y3_in = y3 >= 0 && y3 < args.src_tensor.Height();\n";
95 c += " bool y4_in = y4 >= 0 && y4 < args.src_tensor.Height();\n";
96 c += " x0 = clamp(x0, 0, args.src_tensor.Width() - 1);\n";
97 c += " x1 = clamp(x1, 0, args.src_tensor.Width() - 1);\n";
98 c += " x2 = clamp(x2, 0, args.src_tensor.Width() - 1);\n";
99 c += " y0 = clamp(y0, 0, args.src_tensor.Height() - 1);\n";
100 c += " y1 = clamp(y1, 0, args.src_tensor.Height() - 1);\n";
101 c += " y2 = clamp(y2, 0, args.src_tensor.Height() - 1);\n";
102 c += " y3 = clamp(y3, 0, args.src_tensor.Height() - 1);\n";
103 c += " y4 = clamp(y4, 0, args.src_tensor.Height() - 1);\n";
104 if (src_tensor_type == TensorStorageType::BUFFER) {
105 c += " __global FLT4* src_loc = "
106 "args.src_tensor.GetPtrWithSliceOffset(S);\n";
107 }
108 }
109 if (local_mem_uploads || weights_are_buffer) {
110 W[0] = "f[0]";
111 W[1] = "f[1]";
112 W[2] = "f[2]";
113 W[3] = "f[3]";
114 W[4] = "f[4]";
115 W[5] = "f[5]";
116 W[6] = "f[6]";
117 W[7] = "f[7]";
118 W[8] = "f[8]";
119 bias = "f[9]";
120 }
121 auto read_3x_line = [&](int y) {
122 const std::string yc = "y" + std::to_string(y);
123 if (src_tensor_type == TensorStorageType::BUFFER) {
124 const std::string y_in = "y" + std::to_string(y) + "_in";
125 c += " s0 = src_loc[args.src_tensor.GetWHOffset(x0, " + yc +
126 ")] * INIT_FLT(x0_in && " + y_in + ");\n";
127 c += " s1 = src_loc[args.src_tensor.GetWHOffset(x1, " + yc +
128 ")] * INIT_FLT(x1_in && " + y_in + ");\n";
129 c += " s2 = src_loc[args.src_tensor.GetWHOffset(x2, " + yc +
130 ")] * INIT_FLT(x2_in && " + y_in + ");\n";
131 } else if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) {
132 const std::string y_in = "y" + std::to_string(y) + "_in";
133 c += " s0 = args.src_tensor.Read(x0, " + yc +
134 ", S) * INIT_FLT(x0_in && " + y_in + ");\n";
135 c += " s1 = args.src_tensor.Read(x1, " + yc +
136 ", S) * INIT_FLT(x1_in && " + y_in + ");\n";
137 c += " s2 = args.src_tensor.Read(x2, " + yc +
138 ", S) * INIT_FLT(x2_in && " + y_in + ");\n";
139 } else {
140 c += " s0 = args.src_tensor.Read(x0, " + yc + ", S);\n";
141 c += " s1 = args.src_tensor.Read(x1, " + yc + ", S);\n";
142 c += " s2 = args.src_tensor.Read(x2, " + yc + ", S);\n";
143 }
144 };
145 read_3x_line(0);
146 c += " r0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
147 c += " r0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
148 c += " r0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
149 read_3x_line(1);
150 c += " r0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
151 c += " r0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
152 c += " r0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
153 read_3x_line(2);
154 c += " r0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
155 c += " r0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
156 c += " r0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
157 c += " l0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
158 c += " l0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
159 c += " l0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
160 read_3x_line(3);
161 c += " l0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
162 c += " l0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
163 c += " l0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
164 read_3x_line(4);
165 c += " l0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
166 c += " l0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
167 c += " l0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
168 if (!weights_are_buffer) {
169 c += " FLT4 bias = args.weights.Read(9, S);\n";
170 }
171 c += " r0 += TO_ACCUM_TYPE(" + bias + ");\n";
172 c += " l0 += TO_ACCUM_TYPE(" + bias + ");\n";
173 c += R"(
174 if (Y < args.dst_tensor.Height()) {
175 FLT4 value = TO_FLT4(r0);
176 args.dst_tensor.Write(value, X, Y, S);
177 }
178 if (Y + 1 < args.dst_tensor.Height()) {
179 FLT4 value = TO_FLT4(l0);
180 args.dst_tensor.Write(value, X, Y + 1, S);
181 }
182 }
183 )";
184
185 return c;
186 }
187
188 } // namespace
189
GetGridSize() const190 int3 DepthWiseConv3x3StrideH2::GetGridSize() const {
191 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
192 const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
193 const int grid_z = dst_[0]->Slices();
194 return int3(grid_x, grid_y, grid_z);
195 }
196
GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups) const197 void DepthWiseConv3x3StrideH2::GetPossibleKernelWorkGroups(
198 TuningType tuning_type, const GpuInfo& gpu_info,
199 const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
200 if (local_mem_uploads_) {
201 work_groups->push_back(work_group_size_);
202 } else {
203 GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
204 work_groups);
205 }
206 }
207
CreateDepthWiseConv3x3StrideH2(const OperationDef & definition,const DepthwiseConvolution2DAttributes & attr,const GpuInfo & gpu_info)208 DepthWiseConv3x3StrideH2 CreateDepthWiseConv3x3StrideH2(
209 const OperationDef& definition,
210 const DepthwiseConvolution2DAttributes& attr, const GpuInfo& gpu_info) {
211 bool weights_are_buffer = !gpu_info.SupportsImages() ||
212 gpu_info.IsPowerVR() || gpu_info.IsMali() ||
213 gpu_info.IsApple();
214
215 DepthWiseConv3x3StrideH2 desc(definition);
216 desc.local_mem_uploads_ = weights_are_buffer && gpu_info.IsPowerVR();
217 desc.work_group_size_ = int3(8, 4, 1);
218 desc.code_ = GetKernelDepthWiseConv3x3StrideH2(definition, weights_are_buffer,
219 desc.local_mem_uploads_);
220 auto src_desc = definition.src_tensors[0];
221 src_desc.SetAddressMode(AddressMode::kZero);
222 desc.AddSrcTensor("src_tensor", src_desc);
223 desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
224
225 desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
226 desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
227 desc.args_.AddInt("stride_x", attr.strides.w);
228 desc.args_.AddInt("dilation_x", attr.dilations.w);
229
230 desc.UploadWeightsAndBiases(attr.weights, attr.bias, weights_are_buffer);
231 return desc;
232 }
233
IsDepthWiseConv3x3StrideH2Supported(const DepthwiseConvolution2DAttributes & attr)234 bool IsDepthWiseConv3x3StrideH2Supported(
235 const DepthwiseConvolution2DAttributes& attr) {
236 return attr.weights.shape.o == 1 && attr.weights.shape.h == 3 &&
237 attr.weights.shape.w == 3 && attr.strides.h == 2 &&
238 attr.dilations.h == 1;
239 }
240
241 } // namespace gpu
242 } // namespace tflite
243