1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17
18 #include <cstring>
19 #include <string>
20
21 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23
24 namespace tflite {
25 namespace gpu {
26
ConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)27 ConverterToConvWeights::ConverterToConvWeights(
28 const OperationDef& definition, const WeightsDescription& weights_desc)
29 : GPUOperation(definition), weights_desc_(weights_desc) {
30 code_ = GetConverterToConvWeightsCode(definition_, weights_desc_);
31 }
32
ConverterToConvWeights(ConverterToConvWeights && operation)33 ConverterToConvWeights::ConverterToConvWeights(
34 ConverterToConvWeights&& operation)
35 : GPUOperation(std::move(operation)),
36 weights_desc_(std::move(operation.weights_desc_)) {}
37
operator =(ConverterToConvWeights && operation)38 ConverterToConvWeights& ConverterToConvWeights::operator=(
39 ConverterToConvWeights&& operation) {
40 if (this != &operation) {
41 weights_desc_ = std::move(operation.weights_desc_);
42 GPUOperation::operator=(std::move(operation));
43 }
44 return *this;
45 }
46
GetConverterToConvWeightsCode(const OperationDef & op_def,const WeightsDescription & conv_weights_desc)47 std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
48 const OperationDef& op_def, const WeightsDescription& conv_weights_desc) {
49 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
50 args_.AddFloat("mask_x");
51 args_.AddFloat("mask_y");
52 args_.AddFloat("mask_z");
53 args_.AddFloat("mask_w");
54
55 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
56 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
57 std::vector<int32_t> remap(conv_weights_desc.spatial_remap.size());
58 for (int i = 0; i < remap.size(); ++i) {
59 remap[i] = conv_weights_desc.spatial_remap[i];
60 }
61 BufferDescriptor desc;
62 desc.element_type = DataType::INT32;
63 desc.element_size = 1;
64 desc.memory_type = MemoryType::GLOBAL;
65 desc.size = remap.size() * sizeof(int32_t);
66 desc.data.resize(desc.size);
67 std::memcpy(desc.data.data(), remap.data(), desc.size);
68 args_.AddObject("spatial_remap",
69 absl::make_unique<BufferDescriptor>(std::move(desc)));
70 }
71
72 std::string c;
73 c += "MAIN_FUNCTION($0) {\n";
74 c += " int O = GLOBAL_ID_0 * 4;\n";
75 c += " int I = GLOBAL_ID_1;\n";
76 c += " int Z = GLOBAL_ID_2;\n";
77 c += " int W = Z % args.src_tensor.Width();\n";
78 c += " int H = Z / args.src_tensor.Width();\n";
79 c += " if (O >= args.src_tensor.Batch() || I >= args.src_tensor.Slices() || "
80 "H >= args.src_tensor.Height()) return;\n";
81 std::string x_kern = "W";
82 std::string y_kern = "H";
83 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
84 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
85 c += " int spatial_linear = H * args.src_tensor.Width() + W;\n";
86 c += " int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
87 c += " int w_remap = linear_remap % args.src_tensor.Width();\n";
88 c += " int h_remap = linear_remap / args.src_tensor.Width();\n";
89 x_kern = "w_remap";
90 y_kern = "h_remap";
91 }
92 const std::string coords = x_kern + ", " + y_kern;
93 c += " FLT4 v0 = args.src_tensor.Read(" + coords + ", I, O + 0);\n";
94 c += " FLT4 v1 = INIT_FLT4(0.0f);\n";
95 c += " FLT4 v2 = INIT_FLT4(0.0f);\n";
96 c += " FLT4 v3 = INIT_FLT4(0.0f);\n";
97 c += " if (O + 1 < args.src_tensor.Batch()) {\n";
98 c += " v1 = args.src_tensor.Read(" + coords + ", I, O + 1);\n";
99 c += " }\n";
100 c += " if (O + 2 < args.src_tensor.Batch()) {\n";
101 c += " v2 = args.src_tensor.Read(" + coords + ", I, O + 2);\n";
102 c += " }\n";
103 c += " if (O + 3 < args.src_tensor.Batch()) {\n";
104 c += " v3 = args.src_tensor.Read(" + coords + ", I, O + 3);\n";
105 c += " }\n";
106 c += " if (I == args.src_tensor.Slices() - 1) {\n";
107 c += " FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
108 "args.mask_w);\n";
109 c += " v0 *= mask;\n";
110 c += " v1 *= mask;\n";
111 c += " v2 *= mask;\n";
112 c += " v3 *= mask;\n";
113 c += " }\n";
114 if (conv_weights_desc.IsI4O4()) {
115 c += " FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
116 c += " FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
117 c += " FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
118 c += " FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
119 } else if (conv_weights_desc.IsO4I4()) {
120 c += " FLT4 r0 = v0;\n";
121 c += " FLT4 r1 = v1;\n";
122 c += " FLT4 r2 = v2;\n";
123 c += " FLT4 r3 = v3;\n";
124 }
125 if (conv_weights_desc.layout == WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4 ||
126 conv_weights_desc.layout == WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
127 // Writing to 4X Textures 2D
128 AddDstTensor("dst_tensor0", op_def.dst_tensors[0]);
129 AddDstTensor("dst_tensor1", op_def.dst_tensors[1]);
130 AddDstTensor("dst_tensor2", op_def.dst_tensors[2]);
131 AddDstTensor("dst_tensor3", op_def.dst_tensors[3]);
132 c += " int yc = (H * args.src_tensor.Width() + W) * "
133 "args.src_tensor.Slices() + I\n;";
134 c += " args.dst_tensor0.Write2D(r0, O / 4, yc)\n;";
135 c += " args.dst_tensor1.Write2D(r1, O / 4, yc)\n;";
136 c += " args.dst_tensor2.Write2D(r2, O / 4, yc)\n;";
137 c += " args.dst_tensor3.Write2D(r3, O / 4, yc)\n;";
138 c += "}\n";
139 } else {
140 // Writing to linear buffer
141 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
142 c += " int GROUP_SIZE = " +
143 std::to_string(conv_weights_desc.GetOutputGroupSize()) + ";\n";
144 c += " int d_index = O / (GROUP_SIZE * 4);\n";
145 c += " int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
146 std::string index;
147 if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
148 conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
149 index =
150 "((d_index * args.src_tensor.Slices() + I) * "
151 "args.src_tensor.Height() "
152 "+ H) * args.src_tensor.Width() + W";
153 } else if (conv_weights_desc.layout == WeightsLayout::kOHWIOGroupI4O4 ||
154 conv_weights_desc.layout == WeightsLayout::kOHWIOGroupO4I4) {
155 index =
156 "((d_index * args.src_tensor.Height() + H) * args.src_tensor.Width() "
157 "+ "
158 "W) * args.src_tensor.Slices() + I";
159 }
160 c += " int dst_offset = (" + index + ") * GROUP_SIZE + k_index;\n";
161 c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0)\n;";
162 c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1)\n;";
163 c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2)\n;";
164 c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3)\n;";
165 c += "}\n";
166 }
167 return c;
168 }
169
BindArguments(ArgumentsBinder * args)170 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
171 float4 mask = GetMaskForLastPlane(src_[0]->Channels());
172 RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
173 RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
174 RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
175 return args->SetFloat("mask_w", mask.w);
176 }
177
GetGridSize() const178 int3 ConverterToConvWeights::GetGridSize() const {
179 const int out_group_size = weights_desc_.GetOutputGroupSize();
180 const int grid_x =
181 DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
182 const int grid_y = src_[0]->Slices();
183 const int grid_z = src_[0]->Width() * src_[0]->Height();
184 return int3(grid_x, grid_y, grid_z);
185 }
186
CreateConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)187 ConverterToConvWeights CreateConverterToConvWeights(
188 const OperationDef& definition, const WeightsDescription& weights_desc) {
189 return ConverterToConvWeights(definition, weights_desc);
190 }
191
192 } // namespace gpu
193 } // namespace tflite
194