1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/strided_slice_v2.h"
18
19 #include <algorithm>
20 #include <bitset>
21 #include <map>
22 #include <memory>
23 #include <ostream>
24 #include <set>
25 #include <string>
26 #include <vector>
27
28 #include "abstract/abstract_value.h"
29 #include "abstract/dshape.h"
30 #include "abstract/ops/op_infer.h"
31 #include "abstract/ops/primitive_infer_map.h"
32 #include "base/base.h"
33 #include "ir/anf.h"
34 #include "ir/dtype/number.h"
35 #include "ir/dtype/type.h"
36 #include "ir/primitive.h"
37 #include "ir/tensor.h"
38 #include "mindapi/base/shape_vector.h"
39 #include "mindapi/base/shared_ptr.h"
40 #include "mindapi/ir/value.h"
41 #include "mindapi/src/helper.h"
42 #include "mindspore/core/ops/other_ops.h"
43 #include "ops/op_name.h"
44 #include "ops/op_utils.h"
45 #include "ops/primitive_c.h"
46 #include "utils/check_convert_utils.h"
47 #include "utils/convert_utils_base.h"
48 #include "utils/log_adapter.h"
49
50 namespace mindspore {
51 namespace ops {
52 namespace {
// Largest input rank accepted by StridedSliceV2InferShape.
constexpr size_t kStridedSliceMaxDim = 8;
// Radix used by TenToTwoV2 when splitting a mask value into binary digits.
constexpr int64_t n_two = 2;
// Multiplied with sizeof(int64_t) to size the std::bitset used in set_ellipsis_mask.
constexpr int64_t n_eight = 8;
56
TenToTwoV2(int64_t num)57 std::vector<int64_t> TenToTwoV2(int64_t num) {
58 std::vector<int64_t> output;
59 if (num == 0) {
60 output.push_back(0);
61 return output;
62 }
63 while (num != 0) {
64 output.push_back(num % n_two);
65 num /= n_two;
66 }
67
68 return output;
69 }
70
GetAndCheckAttrMaskV2(const PrimitivePtr & primitive,std::vector<int64_t> * begin_pos,std::vector<int64_t> * end_pos,std::vector<int64_t> * ellipsis_pos,std::vector<int64_t> * new_axis_pos,std::vector<int64_t> * shrink_axis_pos)71 void GetAndCheckAttrMaskV2(const PrimitivePtr &primitive, std::vector<int64_t> *begin_pos,
72 std::vector<int64_t> *end_pos, std::vector<int64_t> *ellipsis_pos,
73 std::vector<int64_t> *new_axis_pos, std::vector<int64_t> *shrink_axis_pos) {
74 MS_EXCEPTION_IF_NULL(primitive);
75 auto begin_mask = GetValue<int64_t>(primitive->GetAttr(kBeginMask));
76 auto end_mask = GetValue<int64_t>(primitive->GetAttr(kEndMask));
77 auto ellipsis_mask = GetValue<int64_t>(primitive->GetAttr(kEllipsisMask));
78 auto new_axis_mask = GetValue<int64_t>(primitive->GetAttr(kNewAxisMask));
79 auto shrink_axis_mask = GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask));
80 if (begin_mask < 0 || end_mask < 0 || ellipsis_mask < 0 || new_axis_mask < 0 || shrink_axis_mask < 0) {
81 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', begin_mask or end_mask or ellipsis_mask or new_axis_mask or "
82 "shrink_axis_mask must more zero.";
83 }
84 *begin_pos = TenToTwoV2(begin_mask);
85 *end_pos = TenToTwoV2(end_mask);
86 *ellipsis_pos = TenToTwoV2(ellipsis_mask);
87 *new_axis_pos = TenToTwoV2(new_axis_mask);
88 *shrink_axis_pos = TenToTwoV2(shrink_axis_mask);
89
90 int ellipsis_count = 0;
91 std::vector<int64_t> &_ellipsis_pos = *ellipsis_pos;
92 for (size_t i = 0; i < _ellipsis_pos.size(); i++) {
93 if (_ellipsis_pos[i] == 1) {
94 ellipsis_count++;
95 }
96 }
97 if (ellipsis_count > 1) {
98 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', Only one non-zero bit is allowed in 'ellipsis_mask'.";
99 }
100 return;
101 }
102
GetSlicingLengthForPositiveStridesV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)103 int64_t GetSlicingLengthForPositiveStridesV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
104 int64_t slicing_length = 0;
105 if (strides == 0) {
106 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
107 }
108 if ((start_pos < x_dim) && end_pos >= -x_dim) {
109 if (-x_dim <= start_pos && start_pos < 0) {
110 start_pos += x_dim;
111 }
112 if (start_pos < -x_dim) {
113 start_pos = 0;
114 }
115 if (-x_dim <= end_pos && end_pos < 0) {
116 end_pos += x_dim;
117 }
118 if (end_pos > x_dim) {
119 end_pos = x_dim;
120 }
121 if (start_pos >= end_pos) {
122 slicing_length = 0;
123 } else {
124 slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
125 }
126 }
127 return slicing_length;
128 }
129
GetSlicingLengthForNegativeStridesV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)130 int64_t GetSlicingLengthForNegativeStridesV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
131 int64_t slicing_length = 0;
132 if (strides == 0) {
133 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
134 }
135 if (start_pos >= -x_dim && end_pos < x_dim) {
136 if (start_pos >= 0 && start_pos < x_dim) {
137 start_pos += -x_dim;
138 }
139 if (start_pos >= x_dim) {
140 start_pos = -1;
141 }
142 if (end_pos >= 0 && end_pos < x_dim) {
143 end_pos += -x_dim;
144 }
145 if (end_pos < -x_dim - 1) {
146 end_pos = -x_dim - 1;
147 }
148 if (start_pos <= end_pos) {
149 slicing_length = 0;
150 } else {
151 slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
152 }
153 }
154 return slicing_length;
155 }
156
ComputeSlicingLengthV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)157 int64_t ComputeSlicingLengthV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
158 int64_t slicing_length = 0;
159 if (strides == 0) {
160 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
161 }
162 if (strides > 0) {
163 slicing_length = GetSlicingLengthForPositiveStridesV2(start_pos, end_pos, strides, x_dim);
164 }
165 if (strides < 0) {
166 slicing_length = GetSlicingLengthForNegativeStridesV2(start_pos, end_pos, strides, x_dim);
167 }
168 return slicing_length;
169 }
170
EllipsisInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,std::vector<int64_t> * infer_shape,size_t i,size_t j,bool has_ellipsis)171 void EllipsisInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
172 const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
173 const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
174 bool has_ellipsis) {
175 if (!has_ellipsis) {
176 return;
177 }
178 MS_EXCEPTION_IF_NULL(primitive);
179 size_t x_rank = x_shape.size();
180 size_t slice_len = begin_v.size();
181 std::vector<int64_t> begin_pos;
182 std::vector<int64_t> end_pos;
183 std::vector<int64_t> ellipsis_pos;
184 std::vector<int64_t> new_axis_pos;
185 std::vector<int64_t> shrink_axis_pos;
186 GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
187 size_t num = 0;
188 for (size_t n = j + 1; n < slice_len; n++) {
189 if (new_axis_pos[n] == 1) {
190 num++;
191 }
192 }
193 size_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
194 MS_EXCEPTION_IF_NULL(infer_shape);
195 (void)infer_shape->insert(infer_shape->end(), x_shape.begin() + SizeToLong(i),
196 x_shape.begin() + SizeToLong(i + ellipsis_occupied_dims));
197 j += 1;
198 i += ellipsis_occupied_dims;
199
200 while (i < x_rank || j < slice_len) {
201 int64_t x_dim_size = x_shape[i];
202 int64_t start = begin_v[j];
203 int64_t finish = end_v[j];
204 int64_t strides = strides_v[j];
205 if (j < begin_pos.size() && begin_pos[j] == 1) {
206 start = strides_v[j] < 0 ? -1 : 0;
207 }
208 if (j < end_pos.size() && end_pos[j] == 1) {
209 finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
210 }
211 if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
212 infer_shape->push_back(1);
213 j += 1;
214 continue;
215 }
216 if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
217 if (!(-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
218 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 'strides[" << j << "]' cannot be "
219 << "negative number and 'begin[" << j << "]' must be in [-" << x_shape[i] << ", "
220 << x_shape[i] << ") when 'shrink_axis_mask' is greater than 0, but got 'strides[" << j
221 << "]': " << strides << ", 'begin[" << j << "]': " << start << ".";
222 }
223 j += 1;
224 i += 1;
225 continue;
226 }
227 int64_t slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
228 infer_shape->push_back(slicing_length);
229 i += 1;
230 j += 1;
231 }
232 return;
233 }
234
ComputeInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,const std::vector<int64_t> & x_shape)235 std::vector<int64_t> ComputeInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
236 const std::vector<int64_t> &end_v, const std::vector<int64_t> &strides_v,
237 const std::vector<int64_t> &x_shape) {
238 std::vector<int64_t> begin_pos;
239 std::vector<int64_t> end_pos;
240 std::vector<int64_t> ellipsis_pos;
241 std::vector<int64_t> new_axis_pos;
242 std::vector<int64_t> shrink_axis_pos;
243 GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
244
245 size_t i = 0;
246 size_t j = 0;
247 int64_t start;
248 int64_t finish;
249 int64_t strides;
250 int64_t slicing_length;
251 bool has_ellipsis = false;
252 std::vector<int64_t> infer_shape;
253 infer_shape.clear();
254 size_t slice_len = begin_v.size();
255 size_t x_rank = x_shape.size();
256
257 while (i < x_rank || j < slice_len) {
258 int64_t x_dim_size = x_shape[i];
259 if (j < slice_len) {
260 start = begin_v[j];
261 finish = end_v[j];
262 strides = strides_v[j];
263 if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
264 has_ellipsis = true;
265 break;
266 }
267 if (j < begin_pos.size() && begin_pos[j] == 1) {
268 start = strides_v[j] < 0 ? -1 : 0;
269 }
270 if (j < end_pos.size() && end_pos[j] == 1) {
271 finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
272 }
273 if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
274 infer_shape.push_back(1);
275 j += 1;
276 continue;
277 }
278 if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
279 if (!(-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
280 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 'strides[" << j << "]' cannot be "
281 << "negative number and 'begin[" << j << "]' must be in [-" << x_shape[i] << ", "
282 << x_shape[i] << ") when 'shrink_axis_mask' is greater than 0, but got 'strides["
283 << j << "]': " << strides << ", 'begin[" << j << "]': " << start << ".";
284 }
285 j += 1;
286 i += 1;
287 continue;
288 }
289 } else {
290 start = 0;
291 finish = x_shape[i];
292 strides = 1;
293 }
294 slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
295 infer_shape.push_back(slicing_length);
296 i += 1;
297 j += 1;
298 }
299 EllipsisInferShapeV2(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
300 return infer_shape;
301 }
302
DynamicComputeInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const size_t slice_len)303 ShapeMap DynamicComputeInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
304 const size_t slice_len) {
305 // currently not support mask
306 std::vector<int64_t> begin_pos;
307 std::vector<int64_t> end_pos;
308 std::vector<int64_t> ellipsis_pos;
309 std::vector<int64_t> new_axis_pos;
310 std::vector<int64_t> shrink_axis_pos;
311 GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
312
313 size_t i = 0;
314 size_t j = 0;
315 int64_t start;
316 int64_t finish;
317 int64_t strides;
318 ShapeMap shape_map;
319 std::vector<int64_t> infer_shape;
320 size_t x_rank = x_shape.size();
321
322 while (i < x_rank || j < slice_len) {
323 int64_t slicing_length = -1;
324 int64_t x_dim_size = x_shape[i];
325 if (x_dim_size == 1) {
326 slicing_length = 1;
327 }
328 if (j < slice_len) {
329 if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
330 j += 1;
331 continue;
332 }
333 if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
334 j += 1;
335 i += 1;
336 continue;
337 }
338 }
339 if (j >= slice_len && x_dim_size > 0) {
340 start = 0;
341 finish = x_shape[i];
342 strides = 1;
343 if (finish > 0) {
344 slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
345 }
346 }
347 infer_shape.push_back(slicing_length);
348 i += 1;
349 j += 1;
350 }
351 shape_map[kShape] = infer_shape;
352 return shape_map;
353 }
354
CheckAndGetDynamicSliceV2(const AbstractBasePtr & input_arg,const std::string & arg_name,ShapeVector * slice_value,size_t * slice_len)355 bool CheckAndGetDynamicSliceV2(const AbstractBasePtr &input_arg, const std::string &arg_name, ShapeVector *slice_value,
356 size_t *slice_len) {
357 bool is_dynamic = false;
358 MS_EXCEPTION_IF_NULL(input_arg);
359 auto input_value = input_arg->GetValue();
360 MS_EXCEPTION_IF_NULL(input_value);
361 if (CheckAndConvertUtils::IsTuple(input_arg)) {
362 if (IsValueKnown(input_value)) {
363 *slice_value = CheckAndConvertUtils::CheckTupleInt(arg_name, input_value, "StridedSliceV2");
364 *slice_len = (*slice_value).size();
365 } else {
366 auto tuple_arg = input_arg->GetShape()->cast<abstract::SequenceShapePtr>();
367 MS_EXCEPTION_IF_NULL(tuple_arg);
368 *slice_len = tuple_arg->size();
369 }
370 } else if (CheckAndConvertUtils::IsTensor(input_arg)) {
371 (void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->GetType(), {kInt64, kInt32},
372 "StridedSliceV2");
373 if (input_value->isa<tensor::Tensor>()) {
374 *slice_value =
375 CheckAndConvertUtils::CheckTensorIntValue(arg_name, input_value, "StridedSliceV2", input_arg->GetType());
376 *slice_len = (*slice_value).size();
377 } else {
378 // slice is ValueAny
379 is_dynamic = true;
380 auto slice_shape = CheckAndConvertUtils::GetTensorInputShape("StridedSliceV2", {input_arg}, 0);
381 if (slice_shape->shape().size() != 1) {
382 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', " << arg_name << " must be 1-D, but got"
383 << slice_shape->shape().size() << "-D.";
384 }
385 *slice_len = LongToSize(slice_shape->shape()[0]);
386 }
387 } else {
388 MS_EXCEPTION(TypeError) << "For 'StridedSliceV2', '" << arg_name
389 << "' must be tuple or Tensor, but got: " << input_arg->GetType()->ToString() << ".";
390 }
391
392 if (arg_name == "strides") {
393 if (std::any_of((*slice_value).begin(), (*slice_value).end(),
394 [](int64_t stride_value) { return stride_value == 0; })) {
395 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
396 }
397 }
398 return is_dynamic;
399 }
400
StridedSliceV2InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)401 abstract::ShapePtr StridedSliceV2InferShape(const PrimitivePtr &primitive,
402 const std::vector<AbstractBasePtr> &input_args) {
403 MS_EXCEPTION_IF_NULL(primitive);
404 auto prim_name = primitive->name();
405 const size_t x_index = 0;
406 auto input_x_shape = CheckAndConvertUtils::GetTensorInputShape("StridedSliceV2", {input_args[x_index]}, 0);
407 if (input_x_shape->shape().size() < 1 || input_x_shape->shape().size() > kStridedSliceMaxDim) {
408 MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input_x must be 1D-8D, but got" << input_x_shape->shape().size()
409 << "-D.";
410 }
411 auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[x_index]->GetShape());
412 auto x_shape = shape_map[kShape];
413 bool x_is_dyn =
414 std::any_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value == abstract::Shape::kShapeDimAny; });
415 ShapeVector begin;
416 ShapeVector end;
417 ShapeVector strides;
418 ShapeVector ret_in_shape;
419 size_t begin_len = 0;
420 size_t end_len = 0;
421 size_t stride_len = 0;
422 const size_t begin_index = 1;
423 const size_t end_index = 2;
424 const size_t stride_index = 3;
425 bool begin_dynamic = CheckAndGetDynamicSliceV2(input_args[begin_index], "begin", &begin, &begin_len);
426 bool end_dynamic = CheckAndGetDynamicSliceV2(input_args[end_index], "end", &end, &end_len);
427 bool stride_dynamic = CheckAndGetDynamicSliceV2(input_args[stride_index], "strides", &strides, &stride_len);
428 if (begin_len != stride_len || end_len != stride_len) {
429 MS_EXCEPTION(ValueError) << "For '" << prim_name << "', 'begin', 'end' and 'strides' must have the same length, "
430 << "but got length of 'begin': " << begin_len << ", 'end': " << end_len
431 << ", 'strides': " << stride_len << ".";
432 }
433
434 bool slice_dynamic = false;
435 if (begin_dynamic || end_dynamic || stride_dynamic || x_is_dyn) {
436 slice_dynamic = true;
437 }
438 if (!slice_dynamic) {
439 ret_in_shape = ComputeInferShapeV2(primitive, begin, end, strides, x_shape);
440 return std::make_shared<abstract::Shape>(ret_in_shape);
441 }
442 auto ret_shape_map = DynamicComputeInferShapeV2(primitive, x_shape, begin_len);
443 ret_in_shape = ret_shape_map[kShape];
444
445 return std::make_shared<abstract::Shape>(ret_in_shape);
446 }
447
StridedSliceV2InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)448 TypePtr StridedSliceV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
449 MS_EXCEPTION_IF_NULL(primitive);
450 const size_t x_index = 0;
451 return CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
452 }
453 } // namespace
454
// NOTE(review): the macro is defined outside this file; presumably it emits the common operator
// boilerplate for StridedSliceV2 on top of BaseOperator - confirm against its definition.
MIND_API_OPERATOR_IMPL(StridedSliceV2, BaseOperator);
set_begin_mask(int64_t begin_mask)456 void StridedSliceV2::set_begin_mask(int64_t begin_mask) {
457 (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
458 (void)this->AddAttr(kBeginMask, api::MakeValue(begin_mask));
459 }
get_begin_mask() const460 int64_t StridedSliceV2::get_begin_mask() const {
461 auto value_ptr = GetAttr(kBeginMask);
462 return GetValue<int64_t>(value_ptr);
463 }
set_end_mask(int64_t end_mask)464 void StridedSliceV2::set_end_mask(int64_t end_mask) {
465 (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
466 (void)this->AddAttr(kEndMask, api::MakeValue(end_mask));
467 }
get_end_mask() const468 int64_t StridedSliceV2::get_end_mask() const {
469 auto value_ptr = GetAttr(kEndMask);
470 return GetValue<int64_t>(value_ptr);
471 }
set_ellipsis_mask(int64_t ellipsis_mask)472 void StridedSliceV2::set_ellipsis_mask(int64_t ellipsis_mask) {
473 (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
474 std::bitset<sizeof(int64_t) * n_eight> bs(ellipsis_mask);
475 std::ostringstream buffer;
476 if (bs.count() > 1) {
477 buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask()
478 << ".";
479 MS_EXCEPTION(ValueError) << buffer.str();
480 }
481 (void)this->AddAttr(kEllipsisMask, api::MakeValue(ellipsis_mask));
482 }
get_ellipsis_mask() const483 int64_t StridedSliceV2::get_ellipsis_mask() const {
484 auto value_ptr = GetAttr(kEllipsisMask);
485 return GetValue<int64_t>(value_ptr);
486 }
set_new_axis_mask(int64_t new_axis_mask)487 void StridedSliceV2::set_new_axis_mask(int64_t new_axis_mask) {
488 (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
489 (void)this->AddAttr(kNewAxisMask, api::MakeValue(new_axis_mask));
490 }
get_new_axis_mask() const491 int64_t StridedSliceV2::get_new_axis_mask() const {
492 auto value_ptr = GetAttr(kNewAxisMask);
493 return GetValue<int64_t>(value_ptr);
494 }
set_shrink_axis_mask(int64_t shrink_axis_mask)495 void StridedSliceV2::set_shrink_axis_mask(int64_t shrink_axis_mask) {
496 (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
497 (void)this->AddAttr(kShrinkAxisMask, api::MakeValue(shrink_axis_mask));
498 }
get_shrink_axis_mask() const499 int64_t StridedSliceV2::get_shrink_axis_mask() const {
500 auto value_ptr = GetAttr(kShrinkAxisMask);
501 return GetValue<int64_t>(value_ptr);
502 }
Init(int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask)503 void StridedSliceV2::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
504 int64_t shrink_axis_mask) {
505 this->set_begin_mask(begin_mask);
506 this->set_end_mask(end_mask);
507 this->set_ellipsis_mask(ellipsis_mask);
508 this->set_new_axis_mask(new_axis_mask);
509 this->set_shrink_axis_mask(shrink_axis_mask);
510 }
511
StridedSliceV2Infer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)512 AbstractBasePtr StridedSliceV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
513 const std::vector<AbstractBasePtr> &input_args) {
514 MS_EXCEPTION_IF_NULL(primitive);
515 const int64_t input_num = 4;
516 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
517 return std::make_shared<abstract::AbstractTensor>(StridedSliceV2InferType(primitive, input_args),
518 StridedSliceV2InferShape(primitive, input_args));
519 }
520
// AG means auto generated
// OpInferBase implementation for StridedSliceV2; every method forwards to the matching free
// function defined in this file.
class MIND_API AGStridedSliceV2Infer : public abstract::OpInferBase {
 public:
  // Forwards to StridedSliceV2InferShape.
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return StridedSliceV2InferShape(primitive, input_args);
  }

  // Forwards to StridedSliceV2InferType.
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return StridedSliceV2InferType(primitive, input_args);
  }
  // Forwards to StridedSliceV2Infer, which also validates the number of inputs.
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return StridedSliceV2Infer(engine, primitive, input_args);
  }

  // Inputs 1, 2 and 3 are the ones StridedSliceV2InferShape reads as 'begin', 'end' and 'strides'.
  std::set<int64_t> GetValueDependArgIndices() const override { return {1, 2, 3}; }
};
539
// Registers AGStridedSliceV2Infer as the infer implementation for prim::kPrimStridedSliceV2.
// NOTE(review): the meaning of the trailing `false` argument is defined by the macro, which is
// not visible in this file - confirm before changing it.
REGISTER_PRIMITIVE_OP_INFER_IMPL(StridedSliceV2, prim::kPrimStridedSliceV2, AGStridedSliceV2Infer, false);
541 } // namespace ops
542 } // namespace mindspore
543