/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {

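// Channel multipliers 1, 2 and 4 get dedicated code paths in GetSrcValue();
// every other multiplier falls back to the generic args.ch_multiplier path.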
bool IsSpecializedCase(int channel_multiplier) {
  return channel_multiplier == 1 || channel_multiplier == 2 ||
         channel_multiplier == 4;
}

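// Appends `value` to `*result`, separating it from existing content with
// `delimiter`.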
void AppendToBack(const std::string& value, const std::string& delimiter,
                  std::string* result) {
  if (!result->empty()) {
    *result += delimiter;
  }
  *result += value;
}

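// Emits code that reads one FLT4 of source data at the given coordinates into
// `value_name`, resolving the depthwise channel multiplier m:
//   m == 1: output slice S reads input slice S directly;
//   m == 2: the selected half (xy or zw) of input slice S / 2 has each of its
//           two channels duplicated;
//   m == 4: one component of input slice S / 4 is broadcast to all four lanes;
//   otherwise: output channel 4 * S + k maps to input channel (4 * S + k) / m,
//              i.e. input slice S / m, component ((S % m) * 4 + k) / m.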
std::string GetSrcValue(int channel_multiplier,
                        const std::vector<std::string>& coords,
                        const std::string& value_name) {
  std::string coords_str;
  for (const auto& coord : coords) {
    AppendToBack(coord, ", ", &coords_str);
  }
  std::string c;
  if (channel_multiplier == 1) {
    c += "    " + value_name + " = args.src_tensor.Read(" + coords_str +
         ", S);\n";
  } else if (channel_multiplier == 2) {
    c += "    {int s_layer = S / 2;\n";
    c += "    FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += "    FLT2 t0 = S % 2 == 0 ? src.xy : src.zw;\n";
    c += "    " + value_name + " = INIT_FLT4v4(t0.x, t0.x, t0.y, t0.y);}\n";
  } else if (channel_multiplier == 4) {
    c += "    {int s_layer = S / 4;\n";
    c += "    FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += "    FLT t0 = src.x;\n";
    c += "    int remainder = S % 4;\n";
    c += "    if (remainder == 1) t0 = src.y;\n";
    c += "    if (remainder == 2) t0 = src.z;\n";
    c += "    if (remainder == 3) t0 = src.w;\n";
    c += "    " + value_name + " = INIT_FLT4v4(t0, t0, t0, t0);}\n";
  } else {
    c += "    {int s_layer = S / args.ch_multiplier;\n";
    c += "    FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += "    int s_offset = (S % args.ch_multiplier) * 4;\n";
    c += "    FLT temp_arr[4] = {src.x, src.y, src.z, src.w};\n";
    c += "    src.x = temp_arr[(s_offset + 0) / args.ch_multiplier];\n";
    c += "    src.y = temp_arr[(s_offset + 1) / args.ch_multiplier];\n";
    c += "    src.z = temp_arr[(s_offset + 2) / args.ch_multiplier];\n";
    c += "    src.w = temp_arr[(s_offset + 3) / args.ch_multiplier];\n";
    c += "    " + value_name + " = src;}\n";
  }

  return c;
}

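// Builds an in-bounds condition for the X/Y coordinates of a source read. An
// axis only contributes a check when the tensor storage cannot zero-clamp
// out-of-bounds reads on that axis; the result is empty if no check is needed.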
std::string GetSrcXYCheck(const GpuInfo& gpu_info,
                          const TensorDescriptor& src_desc,
                          const std::string& x_coord,
                          const std::string& y_coord) {
  std::string result;
  if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
    const std::string x_check =
        x_coord + " >= 0 && " + x_coord + " < args.src_tensor.Width()";
    AppendToBack(x_check, " && ", &result);
  }
  if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
    const std::string y_check =
        y_coord + " >= 0 && " + y_coord + " < args.src_tensor.Height()";
    AppendToBack(y_check, " && ", &result);
  }
  return result;
}

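// Filter weights are stored in a buffer on Mali, AMD and Apple GPUs (except
// the A7/A8 generations, which keep textures) and on any GPU without image
// support; all other GPUs read the weights from a texture.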
bool UseBuffersForWeights(const GpuInfo& gpu_info) {
  if (gpu_info.IsApple()) {
    if (gpu_info.apple_info.IsA7GenerationGpu() ||
        gpu_info.apple_info.IsA8GenerationGpu()) {
      return false;
    }
  }
  return !gpu_info.SupportsImages() || gpu_info.IsMali() ||
         gpu_info.IsApple() || gpu_info.IsAMD();
}
}  // namespace

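// When local-memory caching is enabled, the generated kernel is specialized
// for the fixed work group size in the params, so it is adopted verbatim here.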
DepthwiseConv::DepthwiseConv(const OperationDef& definition,
                             const DepthwiseConvParams& params)
    : GPUOperation(definition), params_(params) {
  if (params.UseLocalMem()) {
    work_group_size_ = params.work_group_size;
  }
}

int3 DepthwiseConv::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

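// With local-memory caching the work group size is fixed (the __local caches
// are sized for it), so tuning is skipped; otherwise the generic picker is
// used.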
void DepthwiseConv::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (params_.UseLocalMem()) {
    work_groups->push_back(work_group_size_);
    return;
  }
  GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                        work_groups);
}

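// Emits code that cooperatively loads the source patch needed by one work
// group (the work-group tile plus the kernel halo) into the __local array
// spatial_cache. Loads that would land outside the cache bounds are masked,
// and out-of-bounds source reads are replaced with zeros when the storage
// cannot zero-clamp them.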
std::string DepthwiseConv::GenerateSrcUpload(const GpuInfo& gpu_info) {
  int cache_size_x = params_.work_group_size.x +
                     params_.x_kernel_size * params_.x_dilation_size - 1;
  int cache_size_y = params_.work_group_size.y +
                     params_.y_kernel_size * params_.y_dilation_size - 1;
  int groups_x = DivideRoundUp(cache_size_x, params_.work_group_size.x);
  int groups_y = DivideRoundUp(cache_size_y, params_.work_group_size.y);
  std::string c;
  c += "  __local FLT4 spatial_cache[" + std::to_string(cache_size_y) + "][" +
       std::to_string(cache_size_x) + "];\n";
  for (int gr_y = 0; gr_y < groups_y; ++gr_y) {
    std::string y_offset = std::to_string(params_.work_group_size.y * gr_y);
    std::string ys = "(y_src + " + y_offset + ")";
    std::string ly = "(LOCAL_ID_1 + " + y_offset + ")";
    for (int gr_x = 0; gr_x < groups_x; ++gr_x) {
      std::string x_offset = std::to_string(params_.work_group_size.x * gr_x);
      std::string xs = "(x_src + " + x_offset + ")";
      std::string lx = "(LOCAL_ID_0 + " + x_offset + ")";
      std::string value = "spatial_cache[" + ly + "][" + lx + "]";
      std::string src_value_read_instructions =
          GetSrcValue(params_.channel_multiplier, {xs, ys}, value);
      std::string check =
          GetSrcXYCheck(gpu_info, definition_.src_tensors[0], xs, ys);
      c += "  if (" + lx + " < " + std::to_string(cache_size_x) + " && " + ly +
           " < " + std::to_string(cache_size_y) + ") {\n";
      if (check.empty()) {
        c += src_value_read_instructions;
      } else {
        c += "    if (" + check + ") {\n";
        c += src_value_read_instructions;
        c += "    } else {\n";
        c += "      " + value + " = INIT_FLT4(0.0f);\n";
        c += "    }\n";
      }
      c += "  }\n";
    }
  }
  return c;
}

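// Emits code that cooperatively loads all filter weights for the current slice
// S into the __local array weights_cache, indexed by the linearized local
// thread id; a tail guard covers the case where the number of filter taps is
// not a multiple of the work group size.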
std::string DepthwiseConv::GenerateWeightsUpload(const GpuInfo& gpu_info) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  auto read_weight = [](bool weights_are_buffer, const std::string& lid,
                        int work_group_total_size) {
    if (weights_are_buffer) {
      return "args.weights.Read(S * args.kernels_total_size + " + lid + ")";
    } else {
      return "args.weights.Read(" + lid + ", S)";
    }
  };
  std::string c;
  const int work_group_total_size = params_.GetWorkGroupTotalSize();
  c += "  __local FLT4 weights_cache[" +
       std::to_string(params_.GetKernelsTotalSize()) + "];\n";
  c += "  int linear_local_id = (LOCAL_ID_2 * GROUP_SIZE_1 + LOCAL_ID_1) * "
       "GROUP_SIZE_0 + LOCAL_ID_0;\n";
  const int groups = params_.GetKernelsTotalSize() / work_group_total_size;
  const int remainder = params_.GetKernelsTotalSize() % work_group_total_size;
  for (int i = 0; i < groups; ++i) {
    const std::string lid =
        "linear_local_id + " + std::to_string(work_group_total_size * i);
    c += "  weights_cache[" + lid +
         "] = " + read_weight(weights_are_buffer, lid, work_group_total_size) +
         ";\n";
  }
  if (remainder != 0) {
    const std::string lid =
        "linear_local_id + " + std::to_string(work_group_total_size * groups);
    c += "  if (linear_local_id < " + std::to_string(remainder) + ") {\n";
    c += "    weights_cache[" + lid +
         "] = " + read_weight(weights_are_buffer, lid, work_group_total_size) +
         ";\n";
    c += "  }\n";
  }
  return c;
}

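// Assembles the full depthwise convolution kernel: computes the output
// coordinates (X, Y, optional Z, slice S and optional batch B), optionally
// fills the local source and weights caches, loops over the filter window
// accumulating src * weight into an ACCUM_FLT4, then adds the bias and writes
// the result.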
std::string DepthwiseConv::GenerateCode(const GpuInfo& gpu_info) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  const bool dynamic_weights = definition_.src_tensors.size() == 2;
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  if (dynamic_weights) {
    AddSrcTensor("weights", definition_.src_tensors[1]);
  }
  AddDstTensor("dst_tensor", definition_.dst_tensors[0]);

  std::string c;

  const auto& src_desc = definition_.src_tensors[0];
  c += "MAIN_FUNCTION($0) {\n";
  if (src_desc.HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += "  int linear_id_1 = GLOBAL_ID_1;\n";
    c += "  int Y = linear_id_1 / args.dst_tensor.Depth();\n";
    c += "  int Z = linear_id_1 % args.dst_tensor.Depth();\n";
  } else {
    c += "  int Y = GLOBAL_ID_1;\n";
  }
  c += "  int S = GLOBAL_ID_2;\n";
  c += "  int x_src = X * args.stride_x + args.padding_x;\n";
  c += "  int y_src = Y * args.stride_y + args.padding_y;\n";
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += "  int z_src = Z * args.stride_z + args.padding_z;\n";
  }
  if (params_.use_spatial_caching) {
    c += GenerateSrcUpload(gpu_info);
  }
  if (params_.use_weights_caching) {
    c += GenerateWeightsUpload(gpu_info);
  }
  if (params_.UseLocalMem()) {
    c += "  LOCAL_MEM_BARRIER;\n";
  }
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "S >= args.dst_tensor.Slices()) { \n";
  c += "    return; \n";
  c += "  } \n";
  c += "  ACCUM_FLT4 r = INIT_ACCUM_FLT4(0.0f);\n";
  if (!dynamic_weights && !params_.use_weights_caching) {
    if (weights_are_buffer) {
      c += "  int fx_c = S * args.kernels_total_size;\n";
    } else {
      c += "  int fx_c = 0;\n";
    }
  }
  std::string kernel_size_x =
      dynamic_weights ? "args.weights.Width()" : "args.kernel_size_x";
  std::string kernel_size_y =
      dynamic_weights ? "args.weights.Height()" : "args.kernel_size_y";
  std::string kernel_size_z =
      dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z";
  if (params_.UseLocalMem()) {
    kernel_size_x = std::to_string(params_.x_kernel_size);
    kernel_size_y = std::to_string(params_.y_kernel_size);
    kernel_size_z = std::to_string(params_.z_kernel_size);
  }

  std::string check;
  std::vector<std::string> coords;
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += "  for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
    if (!params_.use_spatial_caching) {
      c += "    int z_c = z_src + kz * args.dilation_z;\n";
      coords.insert(coords.begin(), "z_c");
      if (!src_desc.SupportsZeroClamp(Axis::DEPTH, gpu_info)) {
        c += "    bool inside_z = z_c >= 0 && z_c < args.src_tensor.Depth();\n";
        c += "    z_c = clamp(z_c, 0, args.src_tensor.Depth() - 1);\n";
        AppendToBack("inside_z", " && ", &check);
      }
    }
  }
  if (src_desc.HasAxis(Axis::HEIGHT)) {
    c += "  for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
    if (!params_.use_spatial_caching) {
      c += "    int y_c = y_src + ky * args.dilation_y;\n";
      coords.insert(coords.begin(), "y_c");
      if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        c +=
            "    bool inside_y = y_c >= 0 && y_c < args.src_tensor.Height();\n";
        c += "    y_c = clamp(y_c, 0, args.src_tensor.Height() - 1);\n";
        AppendToBack("inside_y", " && ", &check);
      }
    }
  }
  if (src_desc.HasAxis(Axis::WIDTH)) {
    c += "  for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
    if (!params_.use_spatial_caching) {
      c += "    int x_c = x_src + kx * args.dilation_x;\n";
      coords.insert(coords.begin(), "x_c");
      if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
        c += "    bool inside_x = x_c >= 0 && x_c < args.src_tensor.Width();\n";
        c += "    x_c = clamp(x_c, 0, args.src_tensor.Width() - 1);\n";
        AppendToBack("inside_x", " && ", &check);
      }
    }
  }
  std::string weight_value;
  if (params_.use_weights_caching) {
    std::string weight_index = "ky";
    if (src_desc.HasAxis(Axis::DEPTH)) {
      weight_index =
          "(kz * " + std::to_string(params_.y_kernel_size) + " + ky)";
    }
    weight_value = "weights_cache[" + weight_index + " * " +
                   std::to_string(params_.x_kernel_size) + " + kx]";
  } else {
    weight_value = "f";
    if (dynamic_weights) {
      c += "    FLT4 f = args.weights.Read(kx, ky, S);\n";
    } else {
      if (weights_are_buffer) {
        c += "    FLT4 f = args.weights.Read(fx_c);\n";
      } else {
        c += "    FLT4 f = args.weights.Read(fx_c, S);\n";
      }
    }
  }
  std::string src_value;
  if (params_.use_spatial_caching) {
    std::string loc_x = params_.x_dilation_size == 1
                            ? "kx"
                            : "kx * " + std::to_string(params_.x_dilation_size);
    std::string loc_y = params_.y_dilation_size == 1
                            ? "ky"
                            : "ky * " + std::to_string(params_.y_dilation_size);
    src_value =
        "spatial_cache[LOCAL_ID_1 + " + loc_y + "][LOCAL_ID_0 + " + loc_x + "]";
  } else {
    c += "    FLT4 src_final;\n";
    src_value = "src_final";
    c += GetSrcValue(params_.channel_multiplier, coords, src_value);
    if (!check.empty()) {
      c += "    src_final = src_final * INIT_FLT(" + check + ");\n";
    }
  }
  c += "    r += TO_ACCUM_TYPE(" + src_value + " * " + weight_value + ");\n";
  if (!dynamic_weights && !params_.use_weights_caching) {
    c += "    fx_c++;\n";
  }
  if (src_desc.HasAxis(Axis::WIDTH)) {
    c += "  }\n";
  }
  if (src_desc.HasAxis(Axis::HEIGHT)) {
    c += "  }\n";
  }
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += "  }\n";
  }
  c += "  FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n";
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += "  args.dst_tensor.Write(res0, X, Y, Z, S);\n";
  } else {
    c += "  args.dst_tensor.Write(res0, X, Y, S);\n";
  }
  c += "}\n";
  return c;
}

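// Creates a 2D depthwise convolution with constant weights. On AMD GPUs,
// stride-1 / dilation-1 filters with at least 10 taps enable local-memory
// caching of both the source patch and the weights with a fixed 16x16x1 work
// group.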
DepthwiseConv CreateDepthwiseConvolution2D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = attr.weights.shape.o;
  if (gpu_info.IsAMD()) {
    if (attr.strides.w == 1 && attr.strides.h == 1 && attr.dilations.w == 1 &&
        attr.dilations.h == 1 &&
        attr.weights.shape.w * attr.weights.shape.h >= 10) {
      params.use_weights_caching = true;
      params.use_spatial_caching = true;
      params.x_kernel_size = attr.weights.shape.w;
      params.y_kernel_size = attr.weights.shape.h;
      params.x_dilation_size = attr.dilations.w;
      params.y_dilation_size = attr.dilations.h;
      params.work_group_size = int3(16, 16, 1);
    }
  }
  DepthwiseConv op(definition, params);
  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.args_.AddInt("kernels_total_size",
                  attr.weights.shape.w * attr.weights.shape.h);
  if (!IsSpecializedCase(attr.weights.shape.o)) {
    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
  }
  op.code_ = op.GenerateCode(gpu_info);
  op.UploadWeightsForDWConv2D(attr.weights, weights_are_buffer);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

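// Variant whose filter weights arrive at runtime as a second input tensor
// (definition.src_tensors[1]); the channel multiplier is fixed to 1.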
DepthwiseConv CreateDepthwiseConvolution2DDynamicWeights(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr) {
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = 1;
  DepthwiseConv op(definition, params);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.code_ = op.GenerateCode(gpu_info);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

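// Creates a 3D depthwise convolution with constant weights, adding the depth
// (Z) kernel size, stride, padding and dilation arguments on top of the 2D
// case.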
DepthwiseConv CreateDepthwiseConvolution3D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution3DAttributes& attr) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = attr.weights.shape.o;
  DepthwiseConv op(definition, params);
  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.args_.AddInt("kernel_size_z", attr.weights.shape.d);
  op.args_.AddInt("stride_z", attr.strides.d);
  op.args_.AddInt("padding_z", -attr.padding.prepended.d);
  op.args_.AddInt("dilation_z", attr.dilations.d);
  op.args_.AddInt(
      "kernels_total_size",
      attr.weights.shape.w * attr.weights.shape.h * attr.weights.shape.d);
  if (!IsSpecializedCase(attr.weights.shape.o)) {
    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
  }
  op.code_ = op.GenerateCode(gpu_info);
  op.UploadWeightsForDWConv3D(attr.weights, weights_are_buffer);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

}  // namespace gpu
}  // namespace tflite