/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

namespace {

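// Channel multipliers that have a dedicated read path in GetSrcValue() below.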
bool IsSpecializedCase(int channel_multiplier) {
  return channel_multiplier == 1 || channel_multiplier == 2 ||
         channel_multiplier == 4;
}

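// Appends |value| to |result|, inserting |delimiter| first when |result| is
// not empty.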
void AppendToBack(const std::string& value, const std::string& delimiter,
                  std::string* result) {
  if (!result->empty()) {
    *result += delimiter;
  }
  *result += value;
}

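// Emits shader code that reads the source value for slice S at |coords| into
// |value_name|, expanding channels according to |channel_multiplier|
// (specialized paths for multipliers 1, 2 and 4, generic path otherwise).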
std::string GetSrcValue(int channel_multiplier,
                        const std::vector<std::string>& coords,
                        const std::string& value_name) {
  std::string coords_str;
  for (const auto& coord : coords) {
    AppendToBack(coord, ", ", &coords_str);
  }
  std::string c;
  if (channel_multiplier == 1) {
    c += " " + value_name + " = args.src_tensor.Read(" + coords_str +
         ", S);\n";
  } else if (channel_multiplier == 2) {
    c += " {int s_layer = S / 2;\n";
    c += " FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += " FLT2 t0 = S % 2 == 0 ? src.xy : src.zw;\n";
    c += " " + value_name + " = INIT_FLT4v4(t0.x, t0.x, t0.y, t0.y);}\n";
  } else if (channel_multiplier == 4) {
    c += " {int s_layer = S / 4;\n";
    c += " FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += " FLT t0 = src.x;\n";
    c += " int remainder = S % 4;\n";
    c += " if (remainder == 1) t0 = src.y;\n";
    c += " if (remainder == 2) t0 = src.z;\n";
    c += " if (remainder == 3) t0 = src.w;\n";
    c += " " + value_name + " = INIT_FLT4v4(t0, t0, t0, t0);}\n";
  } else {
    c += " {int s_layer = S / args.ch_multiplier;\n";
    c += " FLT4 src = args.src_tensor.Read(" + coords_str + ", s_layer);\n";
    c += " int s_offset = (S % args.ch_multiplier) * 4;\n";
    c += " FLT temp_arr[4] = {src.x, src.y, src.z, src.w};\n";
    c += " src.x = temp_arr[(s_offset + 0) / args.ch_multiplier];\n";
    c += " src.y = temp_arr[(s_offset + 1) / args.ch_multiplier];\n";
    c += " src.z = temp_arr[(s_offset + 2) / args.ch_multiplier];\n";
    c += " src.w = temp_arr[(s_offset + 3) / args.ch_multiplier];\n";
    c += " " + value_name + " = src;}\n";
  }

  return c;
}

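// Builds a boolean expression that checks x/y coordinates against the source
// tensor bounds for axes where the storage type does not support zero
// clamping; returns an empty string when no checks are needed.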
std::string GetSrcXYCheck(const GpuInfo& gpu_info,
                          const TensorDescriptor& src_desc,
                          const std::string& x_coord,
                          const std::string& y_coord) {
  std::string result;
  if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
    const std::string x_check =
        x_coord + " >= 0 && " + x_coord + " < args.src_tensor.Width()";
    AppendToBack(x_check, " && ", &result);
  }
  if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
    const std::string y_check =
        y_coord + " >= 0 && " + y_coord + " < args.src_tensor.Height()";
    AppendToBack(y_check, " && ", &result);
  }
  return result;
}

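// Returns true when weights should be stored in a buffer rather than a
// texture/image on this GPU.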
bool UseBuffersForWeights(const GpuInfo& gpu_info) {
  if (gpu_info.IsApple()) {
    if (gpu_info.apple_info.IsA7GenerationGpu() ||
        gpu_info.apple_info.IsA8GenerationGpu()) {
      return false;
    }
  }
  return !gpu_info.SupportsImages() || gpu_info.IsMali() ||
         gpu_info.IsApple() || gpu_info.IsAMD();
}
}  // namespace

DepthwiseConv::DepthwiseConv(const OperationDef& definition,
                             const DepthwiseConvParams& params)
    : GPUOperation(definition), params_(params) {
  if (params.UseLocalMem()) {
    work_group_size_ = params.work_group_size;
  }
}

int3 DepthwiseConv::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

void DepthwiseConv::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (params_.UseLocalMem()) {
    work_groups->push_back(work_group_size_);
    return;
  }
  GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                        work_groups);
}

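// Emits code that cooperatively loads the spatial source window needed by the
// whole work group into the local-memory array spatial_cache, zero-filling
// reads that fall outside the source tensor when the storage type cannot
// clamp them.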
std::string DepthwiseConv::GenerateSrcUpload(const GpuInfo& gpu_info) {
  int cache_size_x = params_.work_group_size.x +
                     params_.x_kernel_size * params_.x_dilation_size - 1;
  int cache_size_y = params_.work_group_size.y +
                     params_.y_kernel_size * params_.y_dilation_size - 1;
  int groups_x = DivideRoundUp(cache_size_x, params_.work_group_size.x);
  int groups_y = DivideRoundUp(cache_size_y, params_.work_group_size.y);
  std::string c;
  c += " __local FLT4 spatial_cache[" + std::to_string(cache_size_y) + "][" +
       std::to_string(cache_size_x) + "];\n";
  for (int gr_y = 0; gr_y < groups_y; ++gr_y) {
    std::string y_offset = std::to_string(params_.work_group_size.y * gr_y);
    std::string ys = "(y_src + " + y_offset + ")";
    std::string ly = "(LOCAL_ID_1 + " + y_offset + ")";
    for (int gr_x = 0; gr_x < groups_x; ++gr_x) {
      std::string x_offset = std::to_string(params_.work_group_size.x * gr_x);
      std::string xs = "(x_src + " + x_offset + ")";
      std::string lx = "(LOCAL_ID_0 + " + x_offset + ")";
      std::string value = "spatial_cache[" + ly + "][" + lx + "]";
      std::string src_value_read_instructions =
          GetSrcValue(params_.channel_multiplier, {xs, ys}, value);
      std::string check =
          GetSrcXYCheck(gpu_info, definition_.src_tensors[0], xs, ys);
      c += " if (" + lx + " < " + std::to_string(cache_size_x) + " && " + ly +
           " < " + std::to_string(cache_size_y) + ") {\n";
      if (check.empty()) {
        c += src_value_read_instructions;
      } else {
        c += " if (" + check + ") {\n";
        c += src_value_read_instructions;
        c += " } else {\n";
        c += " " + value + " = INIT_FLT4(0.0f);\n";
        c += " }\n";
      }
      c += " }\n";
    }
  }
  return c;
}

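// Emits code that cooperatively loads all kernel weights for slice S into the
// local-memory array weights_cache.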
std::string DepthwiseConv::GenerateWeightsUpload(const GpuInfo& gpu_info) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  auto read_weight = [](bool weights_are_buffer, const std::string& lid,
                        int work_group_total_size) {
    if (weights_are_buffer) {
      return "args.weights.Read(S * args.kernels_total_size + " + lid + ")";
    } else {
      return "args.weights.Read(" + lid + ", S)";
    }
  };
  std::string c;
  const int work_group_total_size = params_.GetWorkGroupTotalSize();
  c += " __local FLT4 weights_cache[" +
       std::to_string(params_.GetKernelsTotalSize()) + "];\n";
  c += " int linear_local_id = (LOCAL_ID_2 * GROUP_SIZE_1 + LOCAL_ID_1) * "
       "GROUP_SIZE_0 + LOCAL_ID_0;\n";
  const int groups = params_.GetKernelsTotalSize() / work_group_total_size;
  const int remainder = params_.GetKernelsTotalSize() % work_group_total_size;
  for (int i = 0; i < groups; ++i) {
    const std::string lid =
        "linear_local_id + " + std::to_string(work_group_total_size * i);
    c += " weights_cache[" + lid +
         "] = " + read_weight(weights_are_buffer, lid, work_group_total_size) +
         ";\n";
  }
  if (remainder != 0) {
    const std::string lid =
        "linear_local_id + " + std::to_string(work_group_total_size * groups);
    c += " if (linear_local_id < " + std::to_string(remainder) + ") {\n";
    c += " weights_cache[" + lid +
         "] = " + read_weight(weights_are_buffer, lid, work_group_total_size) +
         ";\n";
    c += " }\n";
  }
  return c;
}

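// Builds the complete depthwise convolution kernel: global/batch/depth index
// decomposition, optional local-memory caching of sources and weights, the
// kernel-window accumulation loop, bias addition and the final write.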
std::string DepthwiseConv::GenerateCode(const GpuInfo& gpu_info) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  const bool dynamic_weights = definition_.src_tensors.size() == 2;
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  if (dynamic_weights) {
    AddSrcTensor("weights", definition_.src_tensors[1]);
  }
  AddDstTensor("dst_tensor", definition_.dst_tensors[0]);

  std::string c;

  const auto& src_desc = definition_.src_tensors[0];
  c += "MAIN_FUNCTION($0) {\n";
  if (src_desc.HasAxis(Axis::BATCH)) {
    c += " int linear_id = GLOBAL_ID_0;\n";
    c += " int X = linear_id / args.dst_tensor.Batch();\n";
    c += " int B = linear_id % args.dst_tensor.Batch();\n";
    c += " args.src_tensor.SetBatchRef(B);\n";
    c += " args.dst_tensor.SetBatchRef(B);\n";
  } else {
    c += " int X = GLOBAL_ID_0;\n";
  }
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += " int linear_id_1 = GLOBAL_ID_1;\n";
    c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
    c += " int Z = linear_id_1 % args.dst_tensor.Depth();\n";
  } else {
    c += " int Y = GLOBAL_ID_1;\n";
  }
  c += " int S = GLOBAL_ID_2;\n";
  c += " int x_src = X * args.stride_x + args.padding_x;\n";
  c += " int y_src = Y * args.stride_y + args.padding_y;\n";
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += " int z_src = Z * args.stride_z + args.padding_z;\n";
  }
  if (params_.use_spatial_caching) {
    c += GenerateSrcUpload(gpu_info);
  }
  if (params_.use_weights_caching) {
    c += GenerateWeightsUpload(gpu_info);
  }
  if (params_.UseLocalMem()) {
    c += " LOCAL_MEM_BARRIER;\n";
  }
  c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "S >= args.dst_tensor.Slices()) { \n";
  c += " return; \n";
  c += " } \n";
  c += " ACCUM_FLT4 r = INIT_ACCUM_FLT4(0.0f);\n";
  if (!dynamic_weights && !params_.use_weights_caching) {
    if (weights_are_buffer) {
      c += " int fx_c = S * args.kernels_total_size;\n";
    } else {
      c += " int fx_c = 0;\n";
    }
  }
  std::string kernel_size_x =
      dynamic_weights ? "args.weights.Width()" : "args.kernel_size_x";
  std::string kernel_size_y =
      dynamic_weights ? "args.weights.Height()" : "args.kernel_size_y";
  std::string kernel_size_z =
      dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z";
  if (params_.UseLocalMem()) {
    kernel_size_x = std::to_string(params_.x_kernel_size);
    kernel_size_y = std::to_string(params_.y_kernel_size);
    kernel_size_z = std::to_string(params_.z_kernel_size);
  }

  std::string check;
  std::vector<std::string> coords;
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
    if (!params_.use_spatial_caching) {
      c += " int z_c = z_src + kz * args.dilation_z;\n";
      coords.insert(coords.begin(), "z_c");
      if (!src_desc.SupportsZeroClamp(Axis::DEPTH, gpu_info)) {
        c += " bool inside_z = z_c >= 0 && z_c < args.src_tensor.Depth();\n";
        c += " z_c = clamp(z_c, 0, args.src_tensor.Depth() - 1);\n";
        AppendToBack("inside_z", " && ", &check);
      }
    }
  }
  if (src_desc.HasAxis(Axis::HEIGHT)) {
    c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
    if (!params_.use_spatial_caching) {
      c += " int y_c = y_src + ky * args.dilation_y;\n";
      coords.insert(coords.begin(), "y_c");
      if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        c +=
            " bool inside_y = y_c >= 0 && y_c < args.src_tensor.Height();\n";
        c += " y_c = clamp(y_c, 0, args.src_tensor.Height() - 1);\n";
        AppendToBack("inside_y", " && ", &check);
      }
    }
  }
  if (src_desc.HasAxis(Axis::WIDTH)) {
    c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
    if (!params_.use_spatial_caching) {
      c += " int x_c = x_src + kx * args.dilation_x;\n";
      coords.insert(coords.begin(), "x_c");
      if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
        c += " bool inside_x = x_c >= 0 && x_c < args.src_tensor.Width();\n";
        c += " x_c = clamp(x_c, 0, args.src_tensor.Width() - 1);\n";
        AppendToBack("inside_x", " && ", &check);
      }
    }
  }
  std::string weight_value;
  if (params_.use_weights_caching) {
    std::string weight_index = "ky";
    if (src_desc.HasAxis(Axis::DEPTH)) {
      weight_index =
          "(kz * " + std::to_string(params_.y_kernel_size) + " + ky)";
    }
    weight_value = "weights_cache[" + weight_index + " * " +
                   std::to_string(params_.x_kernel_size) + " + kx]";
  } else {
    weight_value = "f";
    if (dynamic_weights) {
      c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
    } else {
      if (weights_are_buffer) {
        c += " FLT4 f = args.weights.Read(fx_c);\n";
      } else {
        c += " FLT4 f = args.weights.Read(fx_c, S);\n";
      }
    }
  }
  std::string src_value;
  if (params_.use_spatial_caching) {
    std::string loc_x = params_.x_dilation_size == 1
                            ? "kx"
                            : "kx * " + std::to_string(params_.x_dilation_size);
    std::string loc_y = params_.y_dilation_size == 1
                            ? "ky"
                            : "ky * " + std::to_string(params_.y_dilation_size);
    src_value =
        "spatial_cache[LOCAL_ID_1 + " + loc_y + "][LOCAL_ID_0 + " + loc_x + "]";
  } else {
    c += " FLT4 src_final;\n";
    src_value = "src_final";
    c += GetSrcValue(params_.channel_multiplier, coords, src_value);
    if (!check.empty()) {
      c += " src_final = src_final * INIT_FLT(" + check + ");\n";
    }
  }
  c += " r += TO_ACCUM_TYPE(" + src_value + " * " + weight_value + ");\n";
  if (!dynamic_weights && !params_.use_weights_caching) {
    c += " fx_c++;\n";
  }
  if (src_desc.HasAxis(Axis::WIDTH)) {
    c += " }\n";
  }
  if (src_desc.HasAxis(Axis::HEIGHT)) {
    c += " }\n";
  }
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += " }\n";
  }
  c += " FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n";
  if (src_desc.HasAxis(Axis::DEPTH)) {
    c += " args.dst_tensor.Write(res0, X, Y, Z, S);\n";
  } else {
    c += " args.dst_tensor.Write(res0, X, Y, S);\n";
  }
  c += "}\n";
  return c;
}

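// Creates a 2D depthwise convolution with constant weights. On AMD GPUs,
// stride-1, dilation-1 kernels with at least 10 taps additionally enable
// local-memory caching of sources and weights.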
DepthwiseConv CreateDepthwiseConvolution2D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = attr.weights.shape.o;
  if (gpu_info.IsAMD()) {
    if (attr.strides.w == 1 && attr.strides.h == 1 && attr.dilations.w == 1 &&
        attr.dilations.h == 1 &&
        attr.weights.shape.w * attr.weights.shape.h >= 10) {
      params.use_weights_caching = true;
      params.use_spatial_caching = true;
      params.x_kernel_size = attr.weights.shape.w;
      params.y_kernel_size = attr.weights.shape.h;
      params.x_dilation_size = attr.dilations.w;
      params.y_dilation_size = attr.dilations.h;
      params.work_group_size = int3(16, 16, 1);
    }
  }
  DepthwiseConv op(definition, params);
  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.args_.AddInt("kernels_total_size",
                  attr.weights.shape.w * attr.weights.shape.h);
  if (!IsSpecializedCase(attr.weights.shape.o)) {
    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
  }
  op.code_ = op.GenerateCode(gpu_info);
  op.UploadWeightsForDWConv2D(attr.weights, weights_are_buffer);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

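// Creates a 2D depthwise convolution whose weights come from a second runtime
// input tensor instead of being uploaded as constants.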
DepthwiseConv CreateDepthwiseConvolution2DDynamicWeights(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr) {
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = 1;
  DepthwiseConv op(definition, params);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.code_ = op.GenerateCode(gpu_info);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

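// Creates a 3D depthwise convolution with constant weights.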
DepthwiseConv CreateDepthwiseConvolution3D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution3DAttributes& attr) {
  const bool weights_are_buffer = UseBuffersForWeights(gpu_info);
  DepthwiseConv::DepthwiseConvParams params;
  params.channel_multiplier = attr.weights.shape.o;
  DepthwiseConv op(definition, params);
  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.args_.AddInt("kernel_size_z", attr.weights.shape.d);
  op.args_.AddInt("stride_z", attr.strides.d);
  op.args_.AddInt("padding_z", -attr.padding.prepended.d);
  op.args_.AddInt("dilation_z", attr.dilations.d);
  op.args_.AddInt(
      "kernels_total_size",
      attr.weights.shape.w * attr.weights.shape.h * attr.weights.shape.d);
  if (!IsSpecializedCase(attr.weights.shape.o)) {
    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
  }
  op.code_ = op.GenerateCode(gpu_info);
  op.UploadWeightsForDWConv3D(attr.weights, weights_are_buffer);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), attr.bias);
  op.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                   std::move(bias_tensor_desc)));
  return op;
}

}  // namespace gpu
}  // namespace tflite