1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
21
22 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
24
25 namespace tflite {
26 namespace gpu {
27
ConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)28 ConverterToConvWeights::ConverterToConvWeights(
29 const OperationDef& definition, const WeightsDescription& weights_desc)
30 : GPUOperation(definition), weights_desc_(weights_desc) {
31 code_ = GetConverterToConvWeightsCode(definition_, weights_desc_);
32 }
33
ConverterToConvWeights(ConverterToConvWeights && operation)34 ConverterToConvWeights::ConverterToConvWeights(
35 ConverterToConvWeights&& operation)
36 : GPUOperation(std::move(operation)),
37 weights_desc_(std::move(operation.weights_desc_)) {}
38
operator =(ConverterToConvWeights && operation)39 ConverterToConvWeights& ConverterToConvWeights::operator=(
40 ConverterToConvWeights&& operation) {
41 if (this != &operation) {
42 weights_desc_ = std::move(operation.weights_desc_);
43 GPUOperation::operator=(std::move(operation));
44 }
45 return *this;
46 }
47
GetConverterToConvWeightsCode(const OperationDef & op_def,const WeightsDescription & conv_weights_desc)48 std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
49 const OperationDef& op_def, const WeightsDescription& conv_weights_desc) {
50 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
51 args_.AddFloat("mask_x");
52 args_.AddFloat("mask_y");
53 args_.AddFloat("mask_z");
54 args_.AddFloat("mask_w");
55 args_.AddInt("grid_x_size");
56
57 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
58 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
59 std::vector<int32_t> remap(conv_weights_desc.spatial_remap.size());
60 for (int i = 0; i < remap.size(); ++i) {
61 remap[i] = conv_weights_desc.spatial_remap[i];
62 }
63 BufferDescriptor desc;
64 desc.element_type = DataType::INT32;
65 desc.element_size = 1;
66 desc.memory_type = MemoryType::GLOBAL;
67 desc.size = remap.size() * sizeof(int32_t);
68 desc.data.resize(desc.size);
69 std::memcpy(desc.data.data(), remap.data(), desc.size);
70 args_.AddObject("spatial_remap",
71 absl::make_unique<BufferDescriptor>(std::move(desc)));
72 }
73
74 std::string c;
75 c += "MAIN_FUNCTION($0) {\n";
76 c += " int O = GLOBAL_ID_0;\n";
77 c += " int I = GLOBAL_ID_1;\n";
78 c += " int Z = GLOBAL_ID_2;\n";
79 c += " int W = Z % args.src_tensor.Width();\n";
80 c += " int H = Z / args.src_tensor.Width();\n";
81 c += " if (O >= args.grid_x_size || I >= args.src_tensor.Slices() || "
82 "H >= args.src_tensor.Height()) return;\n";
83 c += " O *= 4;\n";
84 std::string x_kern = "W";
85 std::string y_kern = "H";
86 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
87 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
88 c += " int spatial_linear = H * args.src_tensor.Width() + W;\n";
89 c += " int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
90 c += " int w_remap = linear_remap % args.src_tensor.Width();\n";
91 c += " int h_remap = linear_remap / args.src_tensor.Width();\n";
92 x_kern = "w_remap";
93 y_kern = "h_remap";
94 }
95 const std::string coords = x_kern + ", " + y_kern;
96 c += " FLT4 v0 = INIT_FLT4(0.0f);\n";
97 c += " FLT4 v1 = INIT_FLT4(0.0f);\n";
98 c += " FLT4 v2 = INIT_FLT4(0.0f);\n";
99 c += " FLT4 v3 = INIT_FLT4(0.0f);\n";
100 c += " if (O < args.src_tensor.Batch()) {\n";
101 c += " v0 = args.src_tensor.Read(" + coords + ", I, O);\n";
102 c += " }\n";
103 c += " if (O + 1 < args.src_tensor.Batch()) {\n";
104 c += " v1 = args.src_tensor.Read(" + coords + ", I, O + 1);\n";
105 c += " }\n";
106 c += " if (O + 2 < args.src_tensor.Batch()) {\n";
107 c += " v2 = args.src_tensor.Read(" + coords + ", I, O + 2);\n";
108 c += " }\n";
109 c += " if (O + 3 < args.src_tensor.Batch()) {\n";
110 c += " v3 = args.src_tensor.Read(" + coords + ", I, O + 3);\n";
111 c += " }\n";
112 c += " if (I == args.src_tensor.Slices() - 1) {\n";
113 c += " FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
114 "args.mask_w);\n";
115 c += " v0 *= mask;\n";
116 c += " v1 *= mask;\n";
117 c += " v2 *= mask;\n";
118 c += " v3 *= mask;\n";
119 c += " }\n";
120 if (conv_weights_desc.IsI4O4()) {
121 c += " FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
122 c += " FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
123 c += " FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
124 c += " FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
125 } else if (conv_weights_desc.IsO4I4()) {
126 c += " FLT4 r0 = v0;\n";
127 c += " FLT4 r1 = v1;\n";
128 c += " FLT4 r2 = v2;\n";
129 c += " FLT4 r3 = v3;\n";
130 }
131 if (conv_weights_desc.layout ==
132 WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4 ||
133 conv_weights_desc.layout ==
134 WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4) {
135 // Writing to 4X Textures 2D
136 AddDstTensor("dst_tensor0", op_def.dst_tensors[0]);
137 AddDstTensor("dst_tensor1", op_def.dst_tensors[1]);
138 AddDstTensor("dst_tensor2", op_def.dst_tensors[2]);
139 AddDstTensor("dst_tensor3", op_def.dst_tensors[3]);
140 c += " int yc = (H * args.src_tensor.Width() + W) * "
141 "args.src_tensor.Slices() + I;\n";
142 c += " args.dst_tensor0.Write2D(r0, O / 4, yc);\n";
143 c += " args.dst_tensor1.Write2D(r1, O / 4, yc);\n";
144 c += " args.dst_tensor2.Write2D(r2, O / 4, yc);\n";
145 c += " args.dst_tensor3.Write2D(r3, O / 4, yc);\n";
146 c += "}\n";
147 } else {
148 // Writing to linear buffer
149 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
150 c += " int GROUP_SIZE = " +
151 std::to_string(conv_weights_desc.GetOutputGroupSize()) + ";\n";
152 c += " int d_index = O / (GROUP_SIZE * 4);\n";
153 c += " int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
154 std::string index;
155 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
156 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
157 index =
158 "((d_index * args.src_tensor.Slices() + I) * "
159 "args.src_tensor.Height() "
160 "+ H) * args.src_tensor.Width() + W";
161 } else if (conv_weights_desc.layout ==
162 WeightsLayout::kOSpatialIOGroupI4O4 ||
163 conv_weights_desc.layout ==
164 WeightsLayout::kOSpatialIOGroupO4I4) {
165 index =
166 "((d_index * args.src_tensor.Height() + H) * args.src_tensor.Width() "
167 "+ "
168 "W) * args.src_tensor.Slices() + I";
169 }
170 c += " int dst_offset = (" + index + ") * GROUP_SIZE + k_index;\n";
171 c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0);\n";
172 c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1);\n";
173 c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2);\n";
174 c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3);\n";
175 c += "}\n";
176 }
177 return c;
178 }
179
BindArguments(ArgumentsBinder * args)180 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
181 const int out_group_size = weights_desc_.GetOutputGroupSize();
182 const int grid_x =
183 DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
184 RETURN_IF_ERROR(args->SetInt("grid_x_size", grid_x));
185 float4 mask = GetMaskForLastPlane(src_[0]->Channels());
186 RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
187 RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
188 RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
189 return args->SetFloat("mask_w", mask.w);
190 }
191
GetGridSize() const192 int3 ConverterToConvWeights::GetGridSize() const {
193 const int out_group_size = weights_desc_.GetOutputGroupSize();
194 const int grid_x =
195 DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
196 const int grid_y = src_[0]->Slices();
197 const int grid_z = src_[0]->Width() * src_[0]->Height();
198 return int3(grid_x, grid_y, grid_z);
199 }
200
CreateConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)201 ConverterToConvWeights CreateConverterToConvWeights(
202 const OperationDef& definition, const WeightsDescription& weights_desc) {
203 return ConverterToConvWeights(definition, weights_desc);
204 }
205
206 } // namespace gpu
207 } // namespace tflite
208