1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/strided_slice.h"
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include <bitset>
24 #include "ops/op_utils.h"
25 #include "utils/check_convert_utils.h"
26 #include "abstract/primitive_infer_map.h"
27
28 namespace mindspore {
29 namespace ops {
30 namespace {
// Decomposes `num` into its base-2 digits, least-significant digit first.
// Zero produces a single digit {0}; otherwise digits stop at the highest set bit.
std::vector<int64_t> TenToTwo(int64_t num) {
  std::vector<int64_t> bits;
  do {
    bits.push_back(num % 2);
    num /= 2;
  } while (num != 0);
  return bits;
}
44
get_stride_with_not_zero(int64_t start_pos,int64_t end_pos,int64_t strides)45 int64_t get_stride_with_not_zero(int64_t start_pos, int64_t end_pos, int64_t strides) {
46 int64_t slicing_length = 0;
47 if (strides != 0) {
48 slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
49 } else {
50 MS_EXCEPTION(ValueError) << "the strides must be non-zero but got " << strides;
51 }
52 return slicing_length;
53 }
54
EllipsisInferShape(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,std::vector<int64_t> * infer_shape,size_t i,size_t j,bool has_ellipsis)55 void EllipsisInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
56 const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
57 const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
58 bool has_ellipsis) {
59 if (!has_ellipsis) {
60 return;
61 }
62 MS_EXCEPTION_IF_NULL(primitive);
63 auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
64 MS_EXCEPTION_IF_NULL(strided_slice_prim);
65 size_t x_rank = x_shape.size();
66 size_t slice_len = begin_v.size();
67 std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
68 std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
69 std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
70 std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
71 (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(new_axis_pos.size()), kGreaterEqual,
72 SizeToLong(slice_len), primitive->name());
73
74 size_t num = 0;
75 for (size_t n = j + 1; n < slice_len; n++) {
76 if (new_axis_pos[n] == 1) {
77 num++;
78 }
79 }
80
81 size_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
82 (void)infer_shape->insert(infer_shape->end(), x_shape.begin() + SizeToInt(i),
83 x_shape.begin() + SizeToLong(i + ellipsis_occupied_dims));
84 j += 1;
85 i += ellipsis_occupied_dims;
86
87 while (i < x_rank || j < slice_len) {
88 int64_t x_dim_size = x_shape[i];
89 int64_t start = begin_v[j];
90 int64_t finish = end_v[j];
91 int64_t strides = strides_v[j];
92 if (j < begin_pos.size() || j < slice_len) {
93 start = strides_v[j] < 0 ? -1 : 0;
94 }
95 if (j < end_pos.size() && end_pos[j] == 1) {
96 finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
97 }
98 if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
99 infer_shape->push_back(1);
100 j += 1;
101 continue;
102 }
103 if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
104 if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
105 MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
106 }
107 j += 1;
108 i += 1;
109 continue;
110 }
111 int64_t slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
112 infer_shape->push_back(slicing_length);
113 i += 1;
114 j += 1;
115 }
116 return;
117 }
118
CheckAndGetValidStrides(const AbstractBasePtr & stride_arg)119 const std::vector<int64_t> CheckAndGetValidStrides(const AbstractBasePtr &stride_arg) {
120 MS_EXCEPTION_IF_NULL(stride_arg);
121 auto temp_strides = stride_arg->cast<abstract::AbstractTuplePtr>()->BuildValue();
122 MS_EXCEPTION_IF_NULL(temp_strides);
123 auto strides = GetValue<std::vector<int64_t>>(temp_strides);
124 if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride == 0; })) {
125 MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
126 }
127 return strides;
128 }
129
ComputeInferShape(const PrimitivePtr & primitive,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & strides_v)130 std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
131 const std::vector<int64_t> &end_v, const std::vector<int64_t> &x_shape,
132 const std::vector<int64_t> &strides_v) {
133 std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
134 std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
135 std::vector<int64_t> ellipsis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEllipsisMask)));
136 std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
137 std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
138 size_t i = 0;
139 size_t j = 0;
140 int64_t start;
141 int64_t finish;
142 int64_t strides;
143 int64_t slicing_length;
144 bool has_ellipsis = false;
145 std::vector<int64_t> infer_shape;
146 size_t slice_len = begin_v.size();
147 size_t x_rank = x_shape.size();
148 (void)CheckAndConvertUtils::CheckInteger("end_v size", SizeToLong(end_v.size()), kGreaterEqual, SizeToLong(slice_len),
149 primitive->name());
150 (void)CheckAndConvertUtils::CheckInteger("strides_v size", SizeToLong(strides_v.size()), kGreaterEqual,
151 SizeToLong(slice_len), primitive->name());
152 while (i < x_rank || j < slice_len) {
153 int64_t x_dim_size = x_shape[i];
154 if (j < slice_len) {
155 start = begin_v[j];
156 finish = end_v[j];
157 strides = strides_v[j];
158 if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
159 has_ellipsis = true;
160 break;
161 }
162 if (j < begin_pos.size() && begin_pos[j] == 1) {
163 start = strides_v[j] < 0 ? -1 : 0;
164 }
165 if (j < end_pos.size() && end_pos[j] == 1) {
166 finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
167 }
168 if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
169 infer_shape.push_back(1);
170 j += 1;
171 continue;
172 }
173 if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
174 if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
175 MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
176 }
177 j += 1;
178 i += 1;
179 continue;
180 }
181 } else {
182 start = 0;
183 finish = x_shape[0];
184 strides = 1;
185 }
186 auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
187 MS_EXCEPTION_IF_NULL(strided_slice_prim);
188 slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
189 infer_shape.push_back(slicing_length);
190 i += 1;
191 j += 1;
192 }
193 EllipsisInferShape(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
194 return infer_shape;
195 }
196
StridedSliceInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)197 abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
198 const std::vector<AbstractBasePtr> &input_args) {
199 MS_EXCEPTION_IF_NULL(primitive);
200 auto tuple_begin_v = input_args[kInputIndex1]->cast<abstract::AbstractTuplePtr>();
201 MS_EXCEPTION_IF_NULL(tuple_begin_v);
202 auto temp_begin_v = tuple_begin_v->BuildValue();
203 MS_EXCEPTION_IF_NULL(temp_begin_v);
204 auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
205
206 auto tuple_end_v = input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>();
207 MS_EXCEPTION_IF_NULL(tuple_end_v);
208 auto temp_end_v = tuple_end_v->BuildValue();
209 MS_EXCEPTION_IF_NULL(temp_end_v);
210 auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
211 auto strides_v = CheckAndGetValidStrides(input_args[kInputIndex3]);
212
213 auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
214 auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
215 auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
216 auto ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, x_shape, strides_v);
217 if (min_shape.empty() || max_shape.empty()) {
218 return std::make_shared<abstract::Shape>(ret_in_shape);
219 }
220 auto ret_min_shape = ComputeInferShape(primitive, begin_v, end_v, min_shape, strides_v);
221 auto ret_max_shape = ComputeInferShape(primitive, begin_v, end_v, max_shape, strides_v);
222 return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
223 }
224
StridedSliceInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)225 TypePtr StridedSliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
226 const int64_t x_index = 0;
227 return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
228 }
229 } // namespace
230
set_begin_mask(const int64_t begin_mask)231 void StridedSlice::set_begin_mask(const int64_t begin_mask) {
232 (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
233 (void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
234 }
get_begin_mask() const235 int64_t StridedSlice::get_begin_mask() const {
236 auto value_ptr = GetAttr(kBeginMask);
237 return GetValue<int64_t>(value_ptr);
238 }
set_end_mask(const int64_t end_mask)239 void StridedSlice::set_end_mask(const int64_t end_mask) {
240 (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
241 (void)this->AddAttr(kEndMask, MakeValue(end_mask));
242 }
get_end_mask() const243 int64_t StridedSlice::get_end_mask() const {
244 auto value_ptr = GetAttr(kEndMask);
245 return GetValue<int64_t>(value_ptr);
246 }
set_ellipsis_mask(const int64_t ellipsis_mask)247 void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
248 (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
249 std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
250 std::ostringstream buffer;
251 if (bs.count() > 1) {
252 buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
253 MS_EXCEPTION(ValueError) << buffer.str();
254 }
255 (void)this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
256 }
get_ellipsis_mask() const257 int64_t StridedSlice::get_ellipsis_mask() const {
258 auto value_ptr = GetAttr(kEllipsisMask);
259 return GetValue<int64_t>(value_ptr);
260 }
set_new_axis_mask(const int64_t new_axis_mask)261 void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
262 (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
263 (void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
264 }
get_new_axis_mask() const265 int64_t StridedSlice::get_new_axis_mask() const {
266 auto value_ptr = GetAttr(kNewAxisMask);
267 return GetValue<int64_t>(value_ptr);
268 }
set_shrink_axis_mask(const int64_t shrink_axis_mask)269 void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
270 (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
271 (void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
272 }
get_shrink_axis_mask() const273 int64_t StridedSlice::get_shrink_axis_mask() const {
274 auto value_ptr = GetAttr(kShrinkAxisMask);
275 return GetValue<int64_t>(value_ptr);
276 }
Init(const int64_t begin_mask,const int64_t end_mask,const int64_t ellipsis_mask,const int64_t new_axis_mask,const int64_t shrink_axis_mask)277 void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
278 const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
279 this->set_begin_mask(begin_mask);
280 this->set_end_mask(end_mask);
281 this->set_ellipsis_mask(ellipsis_mask);
282 this->set_new_axis_mask(new_axis_mask);
283 this->set_shrink_axis_mask(shrink_axis_mask);
284 }
285
compute_slicing_length(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim) const286 int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const {
287 int64_t slicing_length = 0;
288 if (strides > 0) {
289 if ((start_pos >= x_dim) || end_pos < -x_dim) {
290 slicing_length = 0;
291 } else {
292 if (-x_dim <= start_pos && start_pos < 0) {
293 start_pos += x_dim;
294 }
295 if (start_pos < -x_dim) {
296 start_pos = 0;
297 }
298 if (-x_dim <= end_pos && end_pos < 0) {
299 end_pos += x_dim;
300 }
301 if (end_pos > x_dim) {
302 end_pos = x_dim;
303 }
304 if (start_pos > end_pos) {
305 slicing_length = 0;
306 } else {
307 slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
308 }
309 }
310 } else {
311 if (start_pos < -x_dim || end_pos >= x_dim) {
312 slicing_length = 0;
313 } else {
314 if (start_pos > 0 && start_pos < x_dim) {
315 start_pos += -x_dim;
316 }
317 if (start_pos >= x_dim) {
318 start_pos = -1;
319 }
320 if (end_pos >= 0 && end_pos < x_dim) {
321 end_pos += -x_dim;
322 }
323 if (end_pos < -x_dim - 1) {
324 end_pos = -x_dim - 1;
325 }
326 if (start_pos <= end_pos) {
327 slicing_length = 0;
328 } else {
329 slicing_length = get_stride_with_not_zero(start_pos, end_pos, strides);
330 }
331 }
332 }
333 return slicing_length;
334 }
335
StridedSliceInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)336 AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
337 const std::vector<AbstractBasePtr> &input_args) {
338 MS_EXCEPTION_IF_NULL(primitive);
339 const int64_t input_num = 4;
340 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
341 return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(primitive, input_args),
342 StridedSliceInferShape(primitive, input_args));
343 }
344 REGISTER_PRIMITIVE_C(kNameStridedSlice, StridedSlice);
345 } // namespace ops
346 } // namespace mindspore
347