/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv_3x3_stride_h2.h"

#include <string>
#include <utility>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {
namespace {
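// Generates device code for a 3x3 depthwise convolution specialized for
// stride 2 along height. Each work item computes two vertically adjacent
// output rows, so five source rows (y0..y4) cover both outputs and the
// shared middle row is read only once.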
std::string GetKernelDepthWiseConv3x3StrideH2(const OperationDef& definition,
                                              bool weights_are_buffer,
                                              bool local_mem_uploads) {
  const auto src_tensor_type = definition.src_tensors[0].storage_type;
  const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
                            src_tensor_type == TensorStorageType::IMAGE_BUFFER;

  std::string c = "MAIN_FUNCTION($0) {\n";
  if (definition.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += R"(
  int Y = GLOBAL_ID_1 * 2;
  int S = GLOBAL_ID_2;

  ACCUM_FLT4 r0 = INIT_ACCUM_FLT4(0.0f);
  ACCUM_FLT4 l0 = INIT_ACCUM_FLT4(0.0f);
)";
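  // Weights occupy 10 FLT4 values per slice (9 filter taps followed by the
  // bias). They are either cached cooperatively in local memory or addressed
  // through a per-slice buffer pointer.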
  if (local_mem_uploads) {
    c += "  __local FLT4 f[10];\n";
    c += "  int local_id = LOCAL_ID_1 * 8 + LOCAL_ID_0;\n";
    c += "  if (local_id < 10) {\n";
    c += "    f[local_id] = args.weights.Read(S * 10 + local_id);\n";
    c += "  }\n";
    c += "  LOCAL_MEM_BARRIER;\n";
  } else if (weights_are_buffer) {
    c += "  __global FLT4* f = args.weights.GetPtr() + S * 10;\n";
  }
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
       "|| S >= args.dst_tensor.Slices()) { \n";
  c += "    return; \n";
  c += "  } \n";
  c += "  FLT4 s0, s1, s2;\n";
  c += "  int x0 = X * args.stride_x + args.padding_x;\n";
  c += "  int x1 = X * args.stride_x + args.padding_x + args.dilation_x;\n";
  c += "  int x2 = X * args.stride_x + args.padding_x + 2 * args.dilation_x;\n";
  c += "  int y0 = Y * 2 + args.padding_y;\n";
  c += "  int y1 = Y * 2 + args.padding_y + 1;\n";
  c += "  int y2 = Y * 2 + args.padding_y + 2;\n";
  c += "  int y3 = Y * 2 + args.padding_y + 3;\n";
  c += "  int y4 = Y * 2 + args.padding_y + 4;\n";
  std::string W[9] = {"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"};
  std::string bias = "bias";
  if (!weights_are_buffer) {
    c += "   FLT4 f0 = args.weights.Read(0, S);\n";
    c += "   FLT4 f1 = args.weights.Read(1, S);\n";
    c += "   FLT4 f2 = args.weights.Read(2, S);\n";
    c += "   FLT4 f3 = args.weights.Read(3, S);\n";
    c += "   FLT4 f4 = args.weights.Read(4, S);\n";
    c += "   FLT4 f5 = args.weights.Read(5, S);\n";
    c += "   FLT4 f6 = args.weights.Read(6, S);\n";
    c += "   FLT4 f7 = args.weights.Read(7, S);\n";
    c += "   FLT4 f8 = args.weights.Read(8, S);\n";
  }
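  // BUFFER and IMAGE_BUFFER storage has no hardware border handling, so
  // out-of-bounds reads are masked explicitly and coordinates are clamped.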
  if (manual_clamp) {
    c += "  bool x0_in = x0 >= 0 && x0 < args.src_tensor.Width();\n";
    c += "  bool x1_in = x1 >= 0 && x1 < args.src_tensor.Width();\n";
    c += "  bool x2_in = x2 >= 0 && x2 < args.src_tensor.Width();\n";
    c += "  bool y0_in = y0 >= 0 && y0 < args.src_tensor.Height();\n";
    c += "  bool y1_in = y1 >= 0 && y1 < args.src_tensor.Height();\n";
    c += "  bool y2_in = y2 >= 0 && y2 < args.src_tensor.Height();\n";
    c += "  bool y3_in = y3 >= 0 && y3 < args.src_tensor.Height();\n";
    c += "  bool y4_in = y4 >= 0 && y4 < args.src_tensor.Height();\n";
    c += "  x0 = clamp(x0, 0, args.src_tensor.Width() - 1);\n";
    c += "  x1 = clamp(x1, 0, args.src_tensor.Width() - 1);\n";
    c += "  x2 = clamp(x2, 0, args.src_tensor.Width() - 1);\n";
    c += "  y0 = clamp(y0, 0, args.src_tensor.Height() - 1);\n";
    c += "  y1 = clamp(y1, 0, args.src_tensor.Height() - 1);\n";
    c += "  y2 = clamp(y2, 0, args.src_tensor.Height() - 1);\n";
    c += "  y3 = clamp(y3, 0, args.src_tensor.Height() - 1);\n";
    c += "  y4 = clamp(y4, 0, args.src_tensor.Height() - 1);\n";
    if (src_tensor_type == TensorStorageType::BUFFER) {
      c += "  __global FLT4* src_loc = "
           "args.src_tensor.GetPtrWithSliceOffset(S);\n";
    }
  }
  if (local_mem_uploads || weights_are_buffer) {
    W[0] = "f[0]";
    W[1] = "f[1]";
    W[2] = "f[2]";
    W[3] = "f[3]";
    W[4] = "f[4]";
    W[5] = "f[5]";
    W[6] = "f[6]";
    W[7] = "f[7]";
    W[8] = "f[8]";
    bias = "f[9]";
  }
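  // Emits the loads of one source row: three horizontally adjacent values
  // into s0..s2, zeroing out-of-bounds reads when manual clamping is needed.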
  auto read_3x_line = [&](int y) {
    const std::string yc = "y" + std::to_string(y);
    if (src_tensor_type == TensorStorageType::BUFFER) {
      const std::string y_in = "y" + std::to_string(y) + "_in";
      c += "    s0 = src_loc[args.src_tensor.GetWHOffset(x0, " + yc +
           ")] * INIT_FLT(x0_in && " + y_in + ");\n";
      c += "    s1 = src_loc[args.src_tensor.GetWHOffset(x1, " + yc +
           ")] * INIT_FLT(x1_in && " + y_in + ");\n";
      c += "    s2 = src_loc[args.src_tensor.GetWHOffset(x2, " + yc +
           ")] * INIT_FLT(x2_in && " + y_in + ");\n";
    } else if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) {
      const std::string y_in = "y" + std::to_string(y) + "_in";
      c += "    s0 = args.src_tensor.Read(x0, " + yc +
           ", S) * INIT_FLT(x0_in && " + y_in + ");\n";
      c += "    s1 = args.src_tensor.Read(x1, " + yc +
           ", S) * INIT_FLT(x1_in && " + y_in + ");\n";
      c += "    s2 = args.src_tensor.Read(x2, " + yc +
           ", S) * INIT_FLT(x2_in && " + y_in + ");\n";
    } else {
      c += "    s0 = args.src_tensor.Read(x0, " + yc + ", S);\n";
      c += "    s1 = args.src_tensor.Read(x1, " + yc + ", S);\n";
      c += "    s2 = args.src_tensor.Read(x2, " + yc + ", S);\n";
    }
  };
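  // r0 accumulates the output at row Y from source rows y0..y2; l0 accumulates
  // the output at row Y + 1 from source rows y2..y4. Row y2 is loaded once and
  // reused for both outputs.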
  read_3x_line(0);
  c += "    r0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
  read_3x_line(1);
  c += "    r0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
  read_3x_line(2);
  c += "    r0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
  read_3x_line(3);
  c += "    l0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
  read_3x_line(4);
  c += "    l0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
  c += "    l0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
  if (!weights_are_buffer) {
    c += "   FLT4 bias = args.weights.Read(9, S);\n";
  }
  c += "  r0 += TO_ACCUM_TYPE(" + bias + ");\n";
  c += "  l0 += TO_ACCUM_TYPE(" + bias + ");\n";
  c += R"(
  if (Y < args.dst_tensor.Height()) {
    FLT4 value = TO_FLT4(r0);
    args.dst_tensor.Write(value, X, Y, S);
  }
  if (Y + 1 < args.dst_tensor.Height()) {
    FLT4 value = TO_FLT4(l0);
    args.dst_tensor.Write(value, X, Y + 1, S);
  }
}
)";

  return c;
}

}  // namespace

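// Two output rows are produced per work item, so the grid is halved along Y.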
int3 DepthWiseConv3x3StrideH2::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

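// With local-memory weight uploads the kernel assumes the fixed 8x4 work group
// used for the cooperative weight load, so work-group tuning is skipped.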
void DepthWiseConv3x3StrideH2::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (local_mem_uploads_) {
    work_groups->push_back(work_group_size_);
  } else {
    GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                          work_groups);
  }
}

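// Weights are stored in a buffer when the GPU lacks image support or is
// PowerVR, Mali, or Apple; the local-memory upload path is used only on
// PowerVR.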
DepthWiseConv3x3StrideH2 CreateDepthWiseConv3x3StrideH2(
    const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr, const GpuInfo& gpu_info) {
  bool weights_are_buffer = !gpu_info.SupportsImages() ||
                            gpu_info.IsPowerVR() || gpu_info.IsMali() ||
                            gpu_info.IsApple();

  DepthWiseConv3x3StrideH2 desc(definition);
  desc.local_mem_uploads_ = weights_are_buffer && gpu_info.IsPowerVR();
  desc.work_group_size_ = int3(8, 4, 1);
  desc.code_ = GetKernelDepthWiseConv3x3StrideH2(definition, weights_are_buffer,
                                                 desc.local_mem_uploads_);
  auto src_desc = definition.src_tensors[0];
  src_desc.SetAddressMode(AddressMode::kZero);
  desc.AddSrcTensor("src_tensor", src_desc);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
  desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
  desc.args_.AddInt("stride_x", attr.strides.w);
  desc.args_.AddInt("dilation_x", attr.dilations.w);

  desc.UploadWeightsAndBiases(attr.weights, attr.bias, weights_are_buffer);
  return desc;
}

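// This specialization covers depth multiplier 1, 3x3 filters, stride 2 along
// height, and no dilation along height.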
bool IsDepthWiseConv3x3StrideH2Supported(
    const DepthwiseConvolution2DAttributes& attr) {
  return attr.weights.shape.o == 1 && attr.weights.shape.h == 3 &&
         attr.weights.shape.w == 3 && attr.strides.h == 2 &&
         attr.dilations.h == 1;
}

}  // namespace gpu
}  // namespace tflite