/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/strided_slice.h"

#include <string>

#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

namespace {
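// The channel dimension can be read a whole 4-channel slice at a time only
// when the channel stride is 1 and the channel start is a multiple of 4.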
bool Is4Aligned(const SliceAttributes& attr) {
  return attr.strides.c == 1 && attr.starts.c % 4 == 0;
}

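// Computes the per-axis starting offset (x = width, y = height, z = channels,
// w = batch). With a negative stride the slice walks backwards, so the offset
// comes from `ends`: a positive end is used as-is, a non-positive end is
// interpreted relative to the source extent. In the 4-aligned case the channel
// offset is converted from channels to slices.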
int4 GetOffset(const SliceAttributes& attr, int src_width, int src_height,
               int src_channels, int src_batch) {
  int4 offset;
  if (attr.strides.w > 0) {
    offset.x = attr.starts.w;
  } else {
    if (attr.ends.w > 0) {
      offset.x = attr.ends.w;
    } else {
      offset.x = src_width + attr.ends.w;
    }
  }
  if (attr.strides.h > 0) {
    offset.y = attr.starts.h;
  } else {
    if (attr.ends.h > 0) {
      offset.y = attr.ends.h;
    } else {
      offset.y = src_height + attr.ends.h;
    }
  }
  if (attr.strides.c > 0) {
    offset.z = attr.starts.c;
  } else {
    if (attr.ends.c > 0) {
      offset.z = attr.ends.c;
    } else {
      offset.z = src_channels + attr.ends.c;
    }
  }
  if (Is4Aligned(attr)) {
    offset.z /= 4;
  }
  if (attr.strides.b > 0) {
    offset.w = attr.starts.b;
  } else {
    if (attr.ends.b > 0) {
      offset.w = attr.ends.b;
    } else {
      offset.w = src_batch + attr.ends.b;
    }
  }
  return offset;
}

}  // namespace

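// The kernel source is generated at construction time; the 4-aligned fast
// path is selected from the slice attributes.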
StridedSlice::StridedSlice(const OperationDef& definition,
                           const SliceAttributes& attr)
    : GPUOperation(definition), attributes_(attr) {
  work_group_size_ = int3(8, 4, 1);
  code_ = GetStridedSliceCode(definition_, Is4Aligned(attributes_));
}

StridedSlice::StridedSlice(StridedSlice&& operation)
    : GPUOperation(std::move(operation)), attributes_(operation.attributes_) {}

StridedSlice& StridedSlice::operator=(StridedSlice&& operation) {
  if (this != &operation) {
    attributes_ = operation.attributes_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

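// Builds the kernel source. Each work item produces one output texel
// (X, Y, slice S); offsets and strides are exposed as runtime int arguments
// and bound later in BindArguments.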
std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def,
                                              bool alignedx4) {
  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
  args_.AddInt("offset_x");
  args_.AddInt("offset_y");
  args_.AddInt("offset_z");
  args_.AddInt("offset_b");
  args_.AddInt("stride_x");
  args_.AddInt("stride_y");
  args_.AddInt("stride_z");
  args_.AddInt("stride_b");

  const std::string batch_id =
      op_def.dst_tensors[0].HasAxis(Axis::BATCH) ? "B" : "0";
  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += "  int Y = GLOBAL_ID_1;\n";
  c += "  int S = GLOBAL_ID_2;\n";
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "S >= args.dst_tensor.Slices()) { \n";
  c += "    return; \n";
  c += "  } \n";
  c += "  int s_x = X * args.stride_x + args.offset_x;\n";
  c += "  int s_y = Y * args.stride_y + args.offset_y;\n";
  if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int s_b = " + batch_id + " * args.stride_b + args.offset_b;\n";
    c += "  args.src_tensor.SetBatchRef(s_b);\n";
  }
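  // 4-aligned: a single FLT4 read maps directly onto the output slice.
  // Otherwise each of the four output channels is gathered from its own
  // source slice and component.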
  if (alignedx4) {
    c += "  int s_z = S + args.offset_z;\n";
    c += "  FLT4 result = args.src_tensor.Read(s_x, s_y, s_z);\n";
  } else {
    c += "  FLT4 result;\n";
    const std::string postfixes[] = {"x", "y", "z", "w"};
    for (int i = 0; i < 4; ++i) {
      c += "  {\n";
      const std::string channel = "(S * 4 + " + std::to_string(i) + ")";
      c += "    int s_ch = " + channel + " * args.stride_z + args.offset_z;\n";
      c += "    int s_z = min(s_ch >> 2, args.src_tensor.Slices() - 1);\n";
      c += "    int s_z_rem = s_ch & 3;\n";
      c += "    FLT4 t = args.src_tensor.Read(s_x, s_y, s_z);\n";
      c += "    FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
      c += "    result." + postfixes[i] + " = t_ar[s_z_rem];\n";
      c += "  }\n";
    }
  }
  c += "  args.dst_tensor.Write(result, X, Y, S);\n";
  c += "}\n";
  return c;
}

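// Binds the offsets derived from the attributes and the source tensor shape,
// plus the raw strides, to the kernel's runtime arguments.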
absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) {
  int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(),
                          src_[0]->Channels(), src_[0]->Batch());
  RETURN_IF_ERROR(args->SetInt("offset_x", offset.x));
  RETURN_IF_ERROR(args->SetInt("offset_y", offset.y));
  RETURN_IF_ERROR(args->SetInt("offset_z", offset.z));
  RETURN_IF_ERROR(args->SetInt("offset_b", offset.w));
  RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w));
  RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h));
  RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c));
  RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b));
  return absl::OkStatus();
}

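// Grid: width * batch in x, height in y, slices in z, matching the index
// decomposition in the generated kernel.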
int3 StridedSlice::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = dst_[0]->Height();
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

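// Factory helper that builds a StridedSlice operation from the given
// definition and attributes.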
StridedSlice CreateStridedSlice(const OperationDef& definition,
                                const SliceAttributes& attr) {
  return StridedSlice(definition, attr);
}

}  // namespace gpu
}  // namespace tflite