• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17 
18 #include <cstring>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
24 
25 namespace tflite {
26 namespace gpu {
27 
ConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)28 ConverterToConvWeights::ConverterToConvWeights(
29     const OperationDef& definition, const WeightsDescription& weights_desc)
30     : GPUOperation(definition), weights_desc_(weights_desc) {
31   code_ = GetConverterToConvWeightsCode(definition_, weights_desc_);
32 }
33 
ConverterToConvWeights(ConverterToConvWeights && operation)34 ConverterToConvWeights::ConverterToConvWeights(
35     ConverterToConvWeights&& operation)
36     : GPUOperation(std::move(operation)),
37       weights_desc_(std::move(operation.weights_desc_)) {}
38 
operator =(ConverterToConvWeights && operation)39 ConverterToConvWeights& ConverterToConvWeights::operator=(
40     ConverterToConvWeights&& operation) {
41   if (this != &operation) {
42     weights_desc_ = std::move(operation.weights_desc_);
43     GPUOperation::operator=(std::move(operation));
44   }
45   return *this;
46 }
47 
GetConverterToConvWeightsCode(const OperationDef & op_def,const WeightsDescription & conv_weights_desc)48 std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
49     const OperationDef& op_def, const WeightsDescription& conv_weights_desc) {
50   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
51   args_.AddFloat("mask_x");
52   args_.AddFloat("mask_y");
53   args_.AddFloat("mask_z");
54   args_.AddFloat("mask_w");
55   args_.AddInt("grid_x_size");
56 
57   if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
58       conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
59     std::vector<int32_t> remap(conv_weights_desc.spatial_remap.size());
60     for (int i = 0; i < remap.size(); ++i) {
61       remap[i] = conv_weights_desc.spatial_remap[i];
62     }
63     BufferDescriptor desc;
64     desc.element_type = DataType::INT32;
65     desc.element_size = 1;
66     desc.memory_type = MemoryType::GLOBAL;
67     desc.size = remap.size() * sizeof(int32_t);
68     desc.data.resize(desc.size);
69     std::memcpy(desc.data.data(), remap.data(), desc.size);
70     args_.AddObject("spatial_remap",
71                     absl::make_unique<BufferDescriptor>(std::move(desc)));
72   }
73 
74   std::string c;
75   c += "MAIN_FUNCTION($0) {\n";
76   c += "  int O = GLOBAL_ID_0;\n";
77   c += "  int I = GLOBAL_ID_1;\n";
78   c += "  int Z = GLOBAL_ID_2;\n";
79   c += "  int W = Z % args.src_tensor.Width();\n";
80   c += "  int H = Z / args.src_tensor.Width();\n";
81   c += "  if (O >= args.grid_x_size || I >= args.src_tensor.Slices() || "
82        "H >= args.src_tensor.Height()) return;\n";
83   c += "  O *= 4;\n";
84   std::string x_kern = "W";
85   std::string y_kern = "H";
86   if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
87       conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
88     c += "  int spatial_linear = H * args.src_tensor.Width() + W;\n";
89     c += "  int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
90     c += "  int w_remap = linear_remap % args.src_tensor.Width();\n";
91     c += "  int h_remap = linear_remap / args.src_tensor.Width();\n";
92     x_kern = "w_remap";
93     y_kern = "h_remap";
94   }
95   const std::string coords = x_kern + ", " + y_kern;
96   c += "  FLT4 v0 = INIT_FLT4(0.0f);\n";
97   c += "  FLT4 v1 = INIT_FLT4(0.0f);\n";
98   c += "  FLT4 v2 = INIT_FLT4(0.0f);\n";
99   c += "  FLT4 v3 = INIT_FLT4(0.0f);\n";
100   c += "  if (O < args.src_tensor.Batch()) {\n";
101   c += "    v0 = args.src_tensor.Read(" + coords + ", I, O);\n";
102   c += "  }\n";
103   c += "  if (O + 1 < args.src_tensor.Batch()) {\n";
104   c += "    v1 = args.src_tensor.Read(" + coords + ", I, O + 1);\n";
105   c += "  }\n";
106   c += "  if (O + 2 < args.src_tensor.Batch()) {\n";
107   c += "    v2 = args.src_tensor.Read(" + coords + ", I, O + 2);\n";
108   c += "  }\n";
109   c += "  if (O + 3 < args.src_tensor.Batch()) {\n";
110   c += "    v3 = args.src_tensor.Read(" + coords + ", I, O + 3);\n";
111   c += "  }\n";
112   c += "  if (I == args.src_tensor.Slices() - 1) {\n";
113   c += "    FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
114        "args.mask_w);\n";
115   c += "    v0 *= mask;\n";
116   c += "    v1 *= mask;\n";
117   c += "    v2 *= mask;\n";
118   c += "    v3 *= mask;\n";
119   c += "  }\n";
120   if (conv_weights_desc.IsI4O4()) {
121     c += "  FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
122     c += "  FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
123     c += "  FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
124     c += "  FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
125   } else if (conv_weights_desc.IsO4I4()) {
126     c += "  FLT4 r0 = v0;\n";
127     c += "  FLT4 r1 = v1;\n";
128     c += "  FLT4 r2 = v2;\n";
129     c += "  FLT4 r3 = v3;\n";
130   }
131   if (conv_weights_desc.layout ==
132           WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4 ||
133       conv_weights_desc.layout ==
134           WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4) {
135     // Writing to 4X Textures 2D
136     AddDstTensor("dst_tensor0", op_def.dst_tensors[0]);
137     AddDstTensor("dst_tensor1", op_def.dst_tensors[1]);
138     AddDstTensor("dst_tensor2", op_def.dst_tensors[2]);
139     AddDstTensor("dst_tensor3", op_def.dst_tensors[3]);
140     c += "  int yc = (H * args.src_tensor.Width() + W) * "
141          "args.src_tensor.Slices() + I;\n";
142     c += "  args.dst_tensor0.Write2D(r0, O / 4, yc);\n";
143     c += "  args.dst_tensor1.Write2D(r1, O / 4, yc);\n";
144     c += "  args.dst_tensor2.Write2D(r2, O / 4, yc);\n";
145     c += "  args.dst_tensor3.Write2D(r3, O / 4, yc);\n";
146     c += "}\n";
147   } else {
148     // Writing to linear buffer
149     AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
150     c += "  int GROUP_SIZE = " +
151          std::to_string(conv_weights_desc.GetOutputGroupSize()) + ";\n";
152     c += "  int d_index = O / (GROUP_SIZE * 4);\n";
153     c += "  int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
154     std::string index;
155     if (conv_weights_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
156         conv_weights_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
157       index =
158           "((d_index * args.src_tensor.Slices() + I) * "
159           "args.src_tensor.Height() "
160           "+ H) * args.src_tensor.Width() + W";
161     } else if (conv_weights_desc.layout ==
162                    WeightsLayout::kOSpatialIOGroupI4O4 ||
163                conv_weights_desc.layout ==
164                    WeightsLayout::kOSpatialIOGroupO4I4) {
165       index =
166           "((d_index * args.src_tensor.Height() + H) * args.src_tensor.Width() "
167           "+ "
168           "W) * args.src_tensor.Slices() + I";
169     }
170     c += "  int dst_offset = (" + index + ") * GROUP_SIZE + k_index;\n";
171     c += "  args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0);\n";
172     c += "  args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1);\n";
173     c += "  args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2);\n";
174     c += "  args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3);\n";
175     c += "}\n";
176   }
177   return c;
178 }
179 
BindArguments(ArgumentsBinder * args)180 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
181   const int out_group_size = weights_desc_.GetOutputGroupSize();
182   const int grid_x =
183       DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
184   RETURN_IF_ERROR(args->SetInt("grid_x_size", grid_x));
185   float4 mask = GetMaskForLastPlane(src_[0]->Channels());
186   RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
187   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
188   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
189   return args->SetFloat("mask_w", mask.w);
190 }
191 
GetGridSize() const192 int3 ConverterToConvWeights::GetGridSize() const {
193   const int out_group_size = weights_desc_.GetOutputGroupSize();
194   const int grid_x =
195       DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * out_group_size), 4);
196   const int grid_y = src_[0]->Slices();
197   const int grid_z = src_[0]->Width() * src_[0]->Height();
198   return int3(grid_x, grid_y, grid_z);
199 }
200 
CreateConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc)201 ConverterToConvWeights CreateConverterToConvWeights(
202     const OperationDef& definition, const WeightsDescription& weights_desc) {
203   return ConverterToConvWeights(definition, weights_desc);
204 }
205 
206 }  // namespace gpu
207 }  // namespace tflite
208