1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"
17
18 #include <string>
19
20 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
21
22 namespace tflite {
23 namespace gpu {
24
25 namespace {
Is4Aligned(const SliceAttributes & attr)26 bool Is4Aligned(const SliceAttributes& attr) {
27 return attr.strides.c == 1 && attr.starts.c % 4 == 0;
28 }
29
GetOffset(const SliceAttributes & attr,int src_width,int src_height,int src_channels,int src_batch)30 int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
31 int src_channels, int src_batch) {
32 int4 offset;
33 if (attr.strides.w > 0) {
34 offset.x = attr.starts.w;
35 } else {
36 if (attr.ends.w > 0) {
37 offset.x = attr.ends.w;
38 } else {
39 offset.x = src_width + attr.ends.w;
40 }
41 }
42 if (attr.strides.h > 0) {
43 offset.y = attr.starts.h;
44 } else {
45 if (attr.ends.h > 0) {
46 offset.y = attr.ends.h;
47 } else {
48 offset.y = src_height + attr.ends.h;
49 }
50 }
51 if (attr.strides.c > 0) {
52 offset.z = attr.starts.c;
53 } else {
54 if (attr.ends.c > 0) {
55 offset.z = attr.ends.c;
56 } else {
57 offset.z = src_channels + attr.ends.c;
58 }
59 }
60 if (Is4Aligned(attr)) {
61 offset.z /= 4;
62 }
63 if (attr.strides.b > 0) {
64 offset.w = attr.starts.b;
65 } else {
66 if (attr.ends.b > 0) {
67 offset.w = attr.ends.b;
68 } else {
69 offset.w = src_batch + attr.ends.b;
70 }
71 }
72 return offset;
73 }
74
75 } // namespace
76
StridedSlice(const OperationDef & definition,const SliceAttributes & attr)77 StridedSlice::StridedSlice(const OperationDef& definition,
78 const SliceAttributes& attr)
79 : GPUOperation(definition), attributes_(attr) {
80 work_group_size_ = int3(8, 4, 1);
81 code_ = GetStridedSliceCode(definition_, Is4Aligned(attributes_));
82 }
83
StridedSlice(StridedSlice && operation)84 StridedSlice::StridedSlice(StridedSlice&& operation)
85 : GPUOperation(std::move(operation)), attributes_(operation.attributes_) {}
86
operator =(StridedSlice && operation)87 StridedSlice& StridedSlice::operator=(StridedSlice&& operation) {
88 if (this != &operation) {
89 attributes_ = operation.attributes_;
90 GPUOperation::operator=(std::move(operation));
91 }
92 return *this;
93 }
94
GetStridedSliceCode(const OperationDef & op_def,bool alignedx4)95 std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def,
96 bool alignedx4) {
97 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
98 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
99 args_.AddInt("offset_x");
100 args_.AddInt("offset_y");
101 args_.AddInt("offset_z");
102 args_.AddInt("offset_b");
103 args_.AddInt("stride_x");
104 args_.AddInt("stride_y");
105 args_.AddInt("stride_z");
106 args_.AddInt("stride_b");
107
108 const std::string batch_id =
109 op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
110 std::string c;
111 c += "MAIN_FUNCTION($0) {\n";
112 if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
113 c += " int linear_id = GLOBAL_ID_0;\n";
114 c += " int X = linear_id / args.dst_tensor.Batch();\n";
115 c += " int B = linear_id % args.dst_tensor.Batch();\n";
116 c += " args.dst_tensor.SetBatchRef(B);\n";
117 } else {
118 c += " int X = GLOBAL_ID_0;\n";
119 }
120 c += " int Y = GLOBAL_ID_1;\n";
121 c += " int S = GLOBAL_ID_2;\n";
122 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
123 "S >= args.dst_tensor.Slices()) { \n";
124 c += " return; \n";
125 c += " } \n";
126 c += " int s_x = X * args.stride_x + args.offset_x;\n";
127 c += " int s_y = Y * args.stride_y + args.offset_y;\n";
128 if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
129 c += " int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
130 c += " args.src_tensor.SetBatchRef(s_b);\n";
131 }
132 if (alignedx4) {
133 c += " int s_z = S + args.offset_z;\n";
134 c += " FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
135 } else {
136 c += " FLT4 result;\n";
137 const std::string postfixes[] = {"x", "y", "z", "w"};
138 for (int i = 0; i < 4; ++i) {
139 c += " {\n";
140 const std::string channel = "(S * 4 + " + std::to_string(i) + ")";
141 c += " int s_ch = " + channel + " * args.stride_z + args.offset_z;\n";
142 c += " int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
143 c += " int s_z_rem = s_ch & 3;\n";
144 c += " FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
145 c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
146 c += " result." + postfixes[i] + " = t_ar[s_z_rem];\n";
147 c += " }\n";
148 }
149 }
150 c += " args.dst_tensor.Write(result, X, Y, S);\n";
151 c += "}\n";
152 return c;
153 }
154
BindArguments(ArgumentsBinder * args)155 absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) {
156 int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(),
157 src_[0]->Channels(), src_[0]->Batch());
158 RETURN_IF_ERROR(args->SetInt("offset_x", offset.x));
159 RETURN_IF_ERROR(args->SetInt("offset_y", offset.y));
160 RETURN_IF_ERROR(args->SetInt("offset_z", offset.z));
161 RETURN_IF_ERROR(args->SetInt("offset_b", offset.w));
162 RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w));
163 RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h));
164 RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c));
165 RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b));
166 return absl::OkStatus();
167 }
168
GetGridSize() const169 int3 StridedSlice::GetGridSize() const {
170 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
171 const int grid_y = dst_[0]->Height();
172 const int grid_z = dst_[0]->Slices();
173 return int3(grid_x, grid_y, grid_z);
174 }
175
CreateStridedSlice(const OperationDef & definition,const SliceAttributes & attr)176 StridedSlice CreateStridedSlice(const OperationDef& definition,
177 const SliceAttributes& attr) {
178 return StridedSlice(definition, attr);
179 }
180
181 } // namespace gpu
182 } // namespace tflite
183