/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
#include "tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

// Maximum number of kernel taps for which the per-tap offsets are emitted as
// a constant array in the generated shader. NOTE: the value is an assumption;
// kMaxConstArraySize is used below but not defined in this excerpt.
constexpr int kMaxConstArraySize = 9;

class DepthwiseConvolution : public NodeShader {
 public:
  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (ctx.input_shapes.size() != 1) {
      return absl::UnimplementedError(
          "DepthWise Convolution does not support more than 1 runtime tensor");
    }
    const auto& attr =
        absl::any_cast<const DepthwiseConvolution2DAttributes&>(ctx.op_attr);
    auto weights = attr.weights.shape;
    const int offsets_count = weights.h * weights.w;
    const bool offsets_count_too_large = offsets_count > kMaxConstArraySize;
    std::vector<Variable> parameters;
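    // Two code paths: for small kernels, each tap's input offset is
    // precomputed on the host and passed in as a constant array; for large
    // kernels the offsets are recomputed per tap inside the shader so the
    // constant array never exceeds kMaxConstArraySize entries.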
    if (offsets_count_too_large) {
      parameters = {
          {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
          {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
          {"padding_w", attr.padding.prepended.w},
          {"padding_h", attr.padding.prepended.h},
          {"dilation_w", attr.dilations.w},
          {"dilation_h", attr.dilations.h},
          {"kernel_w", weights.w},
          {"kernel_h", weights.h},
          {"src_depth", DivideRoundUp(weights.i, 4)},
          {"channel_multiplier", weights.o},
          {"stride", int2(attr.strides.w, attr.strides.h)},
      };
    } else {
      std::vector<int2> offsets;
      for (int h = 0; h < weights.h; ++h) {
        for (int w = 0; w < weights.w; ++w) {
          offsets.emplace_back(w * attr.dilations.w - attr.padding.prepended.w,
                               h * attr.dilations.h - attr.padding.prepended.h);
        }
      }
      parameters = {
          {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
          {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
          {"offsets_count", offsets_count},
          {"offsets", offsets},
          {"src_depth", DivideRoundUp(weights.i, 4)},
          {"channel_multiplier", weights.o},
          {"stride", int2(attr.strides.w, attr.strides.h)},
      };
    }
    bool non_empty_padding =
        attr.padding.appended.h != 0 || attr.padding.appended.w != 0 ||
        attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0;

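    // The filter weights are rearranged into PIOHW4 layout and bound to the
    // shader as a read-only object, addressed by a flat per-tap index below.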
    std::vector<std::pair<std::string, Object>> objects = {
        {"weights", MakeReadonlyObject(ConvertToPIOHW4(attr.weights))}};

    std::string source;
    if (offsets_count_too_large) {
      source = R"(
      int offsets_count = $kernel_w$ * $kernel_h$;
      int src_layer_offset = (gid.z % $channel_multiplier$) * 4;
      int filter_offset = gid.z * $src_depth$ * offsets_count * 4;
      int i = 0;
      for (int ky = 0; ky < $kernel_h$; ky++) {
        for (int kx = 0; kx < $kernel_w$; kx++, i++) {
          ivec2 coord = gid.xy * $stride$ + ivec2(kx * $dilation_w$ - $padding_w$, ky * $dilation_h$ - $padding_h$);)";
    } else {
      source = R"(
      int offsets_count = $offsets_count$;
      int src_layer_offset = (gid.z % $channel_multiplier$) * 4;
      int filter_offset = gid.z * $src_depth$ * offsets_count * 4;
      for (int i = 0; i < offsets_count; ++i) {
        ivec2 coord = gid.xy * $stride$ + $offsets[i]$;)";
    }
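    // When any padding is present, taps that fall outside the input bounds
    // are skipped; with no padding every tap is in bounds and the check is
    // omitted from the generated shader.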
    if (non_empty_padding) {
      source += R"(
        if (coord.x < 0 || coord.y < 0 ||
            coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) {
          continue;
        })";
    }
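    // Shared inner-loop body: fetch the source pixel, expand its channels by
    // the channel multiplier into input_shifted, and accumulate the product
    // with the tap's weight vector. Note that this per-tap filter_offset
    // shadows the filter_offset emitted at the top of the shader.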
    source += R"(
        int src_layer = gid.z / $channel_multiplier$;
        vec4 input_ = $input_data_0[coord.x, coord.y, src_layer]$;
        vec4 input_shifted = vec4(
            input_[(src_layer_offset + 0) / $channel_multiplier$],
            input_[(src_layer_offset + 1) / $channel_multiplier$],
            input_[(src_layer_offset + 2) / $channel_multiplier$],
            input_[(src_layer_offset + 3) / $channel_multiplier$]
        );
        int filter_offset = gid.z * offsets_count + i;
        value_0 += input_shifted * $weights[filter_offset]$;
      }
)";
    if (offsets_count_too_large) {
      source += R"(
    }
)";
    }
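    // If the op carries a bias tensor, upload it as an extra read-only object
    // and append a per-channel add to the shader.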
    if (!attr.bias.data.empty()) {
      source += "value_0 += $bias[gid.z]$;\n";
      objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)});
    }
    *generated_code = {
        /*parameters=*/std::move(parameters),
        /*objects=*/std::move(objects),
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/
        GetIdealWorkgroupIfPossible(
            *ctx.gpu_info, OperationType::DEPTHWISE_CONVOLUTION,
            HW(attr.weights.shape.h, attr.weights.shape.w), attr.strides,
            OHWI(attr.weights.shape.o, ctx.input_shapes[0][1],
                 ctx.input_shapes[0][2], ctx.input_shapes[0][3])),
        /*source_code=*/std::move(source),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

}  // namespace

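// Factory returning the shader implementation above; callers own the result.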
std::unique_ptr<NodeShader> NewDepthwiseConvolutionNodeShader() {
  return absl::make_unique<DepthwiseConvolution>();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite