• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/strided_slice_v2.h"
18 
19 #include <algorithm>
20 #include <bitset>
21 #include <map>
22 #include <memory>
23 #include <ostream>
24 #include <set>
25 #include <string>
26 #include <vector>
27 
28 #include "abstract/abstract_value.h"
29 #include "abstract/dshape.h"
30 #include "abstract/ops/op_infer.h"
31 #include "abstract/ops/primitive_infer_map.h"
32 #include "base/base.h"
33 #include "ir/anf.h"
34 #include "ir/dtype/number.h"
35 #include "ir/dtype/type.h"
36 #include "ir/primitive.h"
37 #include "ir/tensor.h"
38 #include "mindapi/base/shape_vector.h"
39 #include "mindapi/base/shared_ptr.h"
40 #include "mindapi/ir/value.h"
41 #include "mindapi/src/helper.h"
42 #include "mindspore/core/ops/other_ops.h"
43 #include "ops/op_name.h"
44 #include "ops/op_utils.h"
45 #include "ops/primitive_c.h"
46 #include "utils/check_convert_utils.h"
47 #include "utils/convert_utils_base.h"
48 #include "utils/log_adapter.h"
49 
50 namespace mindspore {
51 namespace ops {
52 namespace {
// Largest rank of input_x accepted by StridedSliceV2InferShape.
constexpr size_t kStridedSliceMaxDim{8};
// Radix used when expanding a mask into its binary digits.
constexpr int64_t n_two{2};
// Bits per byte; used to size the bitset that inspects ellipsis_mask.
constexpr int64_t n_eight{8};

// Expands `num` into its base-2 digits, least-significant digit first: element k is bit k of `num`.
// Zero produces {0}. Callers are expected to pass non-negative masks; a negative input yields -1 "digits".
std::vector<int64_t> TenToTwoV2(int64_t num) {
  std::vector<int64_t> digits;
  do {
    digits.push_back(num % n_two);
    num /= n_two;
  } while (num != 0);
  return digits;
}
70 
GetAndCheckAttrMaskV2(const PrimitivePtr & primitive,std::vector<int64_t> * begin_pos,std::vector<int64_t> * end_pos,std::vector<int64_t> * ellipsis_pos,std::vector<int64_t> * new_axis_pos,std::vector<int64_t> * shrink_axis_pos)71 void GetAndCheckAttrMaskV2(const PrimitivePtr &primitive, std::vector<int64_t> *begin_pos,
72                            std::vector<int64_t> *end_pos, std::vector<int64_t> *ellipsis_pos,
73                            std::vector<int64_t> *new_axis_pos, std::vector<int64_t> *shrink_axis_pos) {
74   MS_EXCEPTION_IF_NULL(primitive);
75   auto begin_mask = GetValue<int64_t>(primitive->GetAttr(kBeginMask));
76   auto end_mask = GetValue<int64_t>(primitive->GetAttr(kEndMask));
77   auto ellipsis_mask = GetValue<int64_t>(primitive->GetAttr(kEllipsisMask));
78   auto new_axis_mask = GetValue<int64_t>(primitive->GetAttr(kNewAxisMask));
79   auto shrink_axis_mask = GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask));
80   if (begin_mask < 0 || end_mask < 0 || ellipsis_mask < 0 || new_axis_mask < 0 || shrink_axis_mask < 0) {
81     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', begin_mask or end_mask or ellipsis_mask or new_axis_mask or "
82                                 "shrink_axis_mask must more zero.";
83   }
84   *begin_pos = TenToTwoV2(begin_mask);
85   *end_pos = TenToTwoV2(end_mask);
86   *ellipsis_pos = TenToTwoV2(ellipsis_mask);
87   *new_axis_pos = TenToTwoV2(new_axis_mask);
88   *shrink_axis_pos = TenToTwoV2(shrink_axis_mask);
89 
90   int ellipsis_count = 0;
91   std::vector<int64_t> &_ellipsis_pos = *ellipsis_pos;
92   for (size_t i = 0; i < _ellipsis_pos.size(); i++) {
93     if (_ellipsis_pos[i] == 1) {
94       ellipsis_count++;
95     }
96   }
97   if (ellipsis_count > 1) {
98     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', Only one non-zero bit is allowed in 'ellipsis_mask'.";
99   }
100   return;
101 }
102 
GetSlicingLengthForPositiveStridesV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)103 int64_t GetSlicingLengthForPositiveStridesV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
104   int64_t slicing_length = 0;
105   if (strides == 0) {
106     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
107   }
108   if ((start_pos < x_dim) && end_pos >= -x_dim) {
109     if (-x_dim <= start_pos && start_pos < 0) {
110       start_pos += x_dim;
111     }
112     if (start_pos < -x_dim) {
113       start_pos = 0;
114     }
115     if (-x_dim <= end_pos && end_pos < 0) {
116       end_pos += x_dim;
117     }
118     if (end_pos > x_dim) {
119       end_pos = x_dim;
120     }
121     if (start_pos >= end_pos) {
122       slicing_length = 0;
123     } else {
124       slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
125     }
126   }
127   return slicing_length;
128 }
129 
GetSlicingLengthForNegativeStridesV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)130 int64_t GetSlicingLengthForNegativeStridesV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
131   int64_t slicing_length = 0;
132   if (strides == 0) {
133     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
134   }
135   if (start_pos >= -x_dim && end_pos < x_dim) {
136     if (start_pos >= 0 && start_pos < x_dim) {
137       start_pos += -x_dim;
138     }
139     if (start_pos >= x_dim) {
140       start_pos = -1;
141     }
142     if (end_pos >= 0 && end_pos < x_dim) {
143       end_pos += -x_dim;
144     }
145     if (end_pos < -x_dim - 1) {
146       end_pos = -x_dim - 1;
147     }
148     if (start_pos <= end_pos) {
149       slicing_length = 0;
150     } else {
151       slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
152     }
153   }
154   return slicing_length;
155 }
156 
ComputeSlicingLengthV2(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim)157 int64_t ComputeSlicingLengthV2(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) {
158   int64_t slicing_length = 0;
159   if (strides == 0) {
160     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
161   }
162   if (strides > 0) {
163     slicing_length = GetSlicingLengthForPositiveStridesV2(start_pos, end_pos, strides, x_dim);
164   }
165   if (strides < 0) {
166     slicing_length = GetSlicingLengthForNegativeStridesV2(start_pos, end_pos, strides, x_dim);
167   }
168   return slicing_length;
169 }
170 
EllipsisInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,std::vector<int64_t> * infer_shape,size_t i,size_t j,bool has_ellipsis)171 void EllipsisInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
172                           const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
173                           const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
174                           bool has_ellipsis) {
175   if (!has_ellipsis) {
176     return;
177   }
178   MS_EXCEPTION_IF_NULL(primitive);
179   size_t x_rank = x_shape.size();
180   size_t slice_len = begin_v.size();
181   std::vector<int64_t> begin_pos;
182   std::vector<int64_t> end_pos;
183   std::vector<int64_t> ellipsis_pos;
184   std::vector<int64_t> new_axis_pos;
185   std::vector<int64_t> shrink_axis_pos;
186   GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
187   size_t num = 0;
188   for (size_t n = j + 1; n < slice_len; n++) {
189     if (new_axis_pos[n] == 1) {
190       num++;
191     }
192   }
193   size_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
194   MS_EXCEPTION_IF_NULL(infer_shape);
195   (void)infer_shape->insert(infer_shape->end(), x_shape.begin() + SizeToLong(i),
196                             x_shape.begin() + SizeToLong(i + ellipsis_occupied_dims));
197   j += 1;
198   i += ellipsis_occupied_dims;
199 
200   while (i < x_rank || j < slice_len) {
201     int64_t x_dim_size = x_shape[i];
202     int64_t start = begin_v[j];
203     int64_t finish = end_v[j];
204     int64_t strides = strides_v[j];
205     if (j < begin_pos.size() && begin_pos[j] == 1) {
206       start = strides_v[j] < 0 ? -1 : 0;
207     }
208     if (j < end_pos.size() && end_pos[j] == 1) {
209       finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
210     }
211     if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
212       infer_shape->push_back(1);
213       j += 1;
214       continue;
215     }
216     if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
217       if (!(-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
218         MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 'strides[" << j << "]' cannot be "
219                                  << "negative number and 'begin[" << j << "]' must be in [-" << x_shape[i] << ", "
220                                  << x_shape[i] << ") when 'shrink_axis_mask' is greater than 0, but got 'strides[" << j
221                                  << "]': " << strides << ", 'begin[" << j << "]': " << start << ".";
222       }
223       j += 1;
224       i += 1;
225       continue;
226     }
227     int64_t slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
228     infer_shape->push_back(slicing_length);
229     i += 1;
230     j += 1;
231   }
232   return;
233 }
234 
ComputeInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,const std::vector<int64_t> & x_shape)235 std::vector<int64_t> ComputeInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
236                                          const std::vector<int64_t> &end_v, const std::vector<int64_t> &strides_v,
237                                          const std::vector<int64_t> &x_shape) {
238   std::vector<int64_t> begin_pos;
239   std::vector<int64_t> end_pos;
240   std::vector<int64_t> ellipsis_pos;
241   std::vector<int64_t> new_axis_pos;
242   std::vector<int64_t> shrink_axis_pos;
243   GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
244 
245   size_t i = 0;
246   size_t j = 0;
247   int64_t start;
248   int64_t finish;
249   int64_t strides;
250   int64_t slicing_length;
251   bool has_ellipsis = false;
252   std::vector<int64_t> infer_shape;
253   infer_shape.clear();
254   size_t slice_len = begin_v.size();
255   size_t x_rank = x_shape.size();
256 
257   while (i < x_rank || j < slice_len) {
258     int64_t x_dim_size = x_shape[i];
259     if (j < slice_len) {
260       start = begin_v[j];
261       finish = end_v[j];
262       strides = strides_v[j];
263       if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
264         has_ellipsis = true;
265         break;
266       }
267       if (j < begin_pos.size() && begin_pos[j] == 1) {
268         start = strides_v[j] < 0 ? -1 : 0;
269       }
270       if (j < end_pos.size() && end_pos[j] == 1) {
271         finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
272       }
273       if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
274         infer_shape.push_back(1);
275         j += 1;
276         continue;
277       }
278       if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
279         if (!(-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
280           MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 'strides[" << j << "]' cannot be "
281                                    << "negative number and 'begin[" << j << "]' must be in [-" << x_shape[i] << ", "
282                                    << x_shape[i] << ") when 'shrink_axis_mask' is greater than 0, but got 'strides["
283                                    << j << "]': " << strides << ", 'begin[" << j << "]': " << start << ".";
284         }
285         j += 1;
286         i += 1;
287         continue;
288       }
289     } else {
290       start = 0;
291       finish = x_shape[i];
292       strides = 1;
293     }
294     slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
295     infer_shape.push_back(slicing_length);
296     i += 1;
297     j += 1;
298   }
299   EllipsisInferShapeV2(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
300   return infer_shape;
301 }
302 
DynamicComputeInferShapeV2(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const size_t slice_len)303 ShapeMap DynamicComputeInferShapeV2(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
304                                     const size_t slice_len) {
305   // currently not support mask
306   std::vector<int64_t> begin_pos;
307   std::vector<int64_t> end_pos;
308   std::vector<int64_t> ellipsis_pos;
309   std::vector<int64_t> new_axis_pos;
310   std::vector<int64_t> shrink_axis_pos;
311   GetAndCheckAttrMaskV2(primitive, &begin_pos, &end_pos, &ellipsis_pos, &new_axis_pos, &shrink_axis_pos);
312 
313   size_t i = 0;
314   size_t j = 0;
315   int64_t start;
316   int64_t finish;
317   int64_t strides;
318   ShapeMap shape_map;
319   std::vector<int64_t> infer_shape;
320   size_t x_rank = x_shape.size();
321 
322   while (i < x_rank || j < slice_len) {
323     int64_t slicing_length = -1;
324     int64_t x_dim_size = x_shape[i];
325     if (x_dim_size == 1) {
326       slicing_length = 1;
327     }
328     if (j < slice_len) {
329       if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
330         j += 1;
331         continue;
332       }
333       if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
334         j += 1;
335         i += 1;
336         continue;
337       }
338     }
339     if (j >= slice_len && x_dim_size > 0) {
340       start = 0;
341       finish = x_shape[i];
342       strides = 1;
343       if (finish > 0) {
344         slicing_length = ComputeSlicingLengthV2(start, finish, strides, x_dim_size);
345       }
346     }
347     infer_shape.push_back(slicing_length);
348     i += 1;
349     j += 1;
350   }
351   shape_map[kShape] = infer_shape;
352   return shape_map;
353 }
354 
CheckAndGetDynamicSliceV2(const AbstractBasePtr & input_arg,const std::string & arg_name,ShapeVector * slice_value,size_t * slice_len)355 bool CheckAndGetDynamicSliceV2(const AbstractBasePtr &input_arg, const std::string &arg_name, ShapeVector *slice_value,
356                                size_t *slice_len) {
357   bool is_dynamic = false;
358   MS_EXCEPTION_IF_NULL(input_arg);
359   auto input_value = input_arg->GetValue();
360   MS_EXCEPTION_IF_NULL(input_value);
361   if (CheckAndConvertUtils::IsTuple(input_arg)) {
362     if (IsValueKnown(input_value)) {
363       *slice_value = CheckAndConvertUtils::CheckTupleInt(arg_name, input_value, "StridedSliceV2");
364       *slice_len = (*slice_value).size();
365     } else {
366       auto tuple_arg = input_arg->GetShape()->cast<abstract::SequenceShapePtr>();
367       MS_EXCEPTION_IF_NULL(tuple_arg);
368       *slice_len = tuple_arg->size();
369     }
370   } else if (CheckAndConvertUtils::IsTensor(input_arg)) {
371     (void)CheckAndConvertUtils::CheckTensorTypeValid(arg_name, input_arg->GetType(), {kInt64, kInt32},
372                                                      "StridedSliceV2");
373     if (input_value->isa<tensor::Tensor>()) {
374       *slice_value =
375         CheckAndConvertUtils::CheckTensorIntValue(arg_name, input_value, "StridedSliceV2", input_arg->GetType());
376       *slice_len = (*slice_value).size();
377     } else {
378       // slice is ValueAny
379       is_dynamic = true;
380       auto slice_shape = CheckAndConvertUtils::GetTensorInputShape("StridedSliceV2", {input_arg}, 0);
381       if (slice_shape->shape().size() != 1) {
382         MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', " << arg_name << " must be 1-D, but got"
383                                  << slice_shape->shape().size() << "-D.";
384       }
385       *slice_len = LongToSize(slice_shape->shape()[0]);
386     }
387   } else {
388     MS_EXCEPTION(TypeError) << "For 'StridedSliceV2', '" << arg_name
389                             << "' must be tuple or Tensor, but got: " << input_arg->GetType()->ToString() << ".";
390   }
391 
392   if (arg_name == "strides") {
393     if (std::any_of((*slice_value).begin(), (*slice_value).end(),
394                     [](int64_t stride_value) { return stride_value == 0; })) {
395       MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input 'strides' can not contain 0.";
396     }
397   }
398   return is_dynamic;
399 }
400 
StridedSliceV2InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)401 abstract::ShapePtr StridedSliceV2InferShape(const PrimitivePtr &primitive,
402                                             const std::vector<AbstractBasePtr> &input_args) {
403   MS_EXCEPTION_IF_NULL(primitive);
404   auto prim_name = primitive->name();
405   const size_t x_index = 0;
406   auto input_x_shape = CheckAndConvertUtils::GetTensorInputShape("StridedSliceV2", {input_args[x_index]}, 0);
407   if (input_x_shape->shape().size() < 1 || input_x_shape->shape().size() > kStridedSliceMaxDim) {
408     MS_EXCEPTION(ValueError) << "For 'StridedSliceV2', input_x must be 1D-8D, but got" << input_x_shape->shape().size()
409                              << "-D.";
410   }
411   auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[x_index]->GetShape());
412   auto x_shape = shape_map[kShape];
413   bool x_is_dyn =
414     std::any_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value == abstract::Shape::kShapeDimAny; });
415   ShapeVector begin;
416   ShapeVector end;
417   ShapeVector strides;
418   ShapeVector ret_in_shape;
419   size_t begin_len = 0;
420   size_t end_len = 0;
421   size_t stride_len = 0;
422   const size_t begin_index = 1;
423   const size_t end_index = 2;
424   const size_t stride_index = 3;
425   bool begin_dynamic = CheckAndGetDynamicSliceV2(input_args[begin_index], "begin", &begin, &begin_len);
426   bool end_dynamic = CheckAndGetDynamicSliceV2(input_args[end_index], "end", &end, &end_len);
427   bool stride_dynamic = CheckAndGetDynamicSliceV2(input_args[stride_index], "strides", &strides, &stride_len);
428   if (begin_len != stride_len || end_len != stride_len) {
429     MS_EXCEPTION(ValueError) << "For '" << prim_name << "', 'begin', 'end' and 'strides' must have the same length, "
430                              << "but got length of 'begin': " << begin_len << ", 'end': " << end_len
431                              << ", 'strides': " << stride_len << ".";
432   }
433 
434   bool slice_dynamic = false;
435   if (begin_dynamic || end_dynamic || stride_dynamic || x_is_dyn) {
436     slice_dynamic = true;
437   }
438   if (!slice_dynamic) {
439     ret_in_shape = ComputeInferShapeV2(primitive, begin, end, strides, x_shape);
440     return std::make_shared<abstract::Shape>(ret_in_shape);
441   }
442   auto ret_shape_map = DynamicComputeInferShapeV2(primitive, x_shape, begin_len);
443   ret_in_shape = ret_shape_map[kShape];
444 
445   return std::make_shared<abstract::Shape>(ret_in_shape);
446 }
447 
StridedSliceV2InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)448 TypePtr StridedSliceV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
449   MS_EXCEPTION_IF_NULL(primitive);
450   const size_t x_index = 0;
451   return CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
452 }
453 }  // namespace
454 
455 MIND_API_OPERATOR_IMPL(StridedSliceV2, BaseOperator);
set_begin_mask(int64_t begin_mask)456 void StridedSliceV2::set_begin_mask(int64_t begin_mask) {
457   (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
458   (void)this->AddAttr(kBeginMask, api::MakeValue(begin_mask));
459 }
get_begin_mask() const460 int64_t StridedSliceV2::get_begin_mask() const {
461   auto value_ptr = GetAttr(kBeginMask);
462   return GetValue<int64_t>(value_ptr);
463 }
set_end_mask(int64_t end_mask)464 void StridedSliceV2::set_end_mask(int64_t end_mask) {
465   (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
466   (void)this->AddAttr(kEndMask, api::MakeValue(end_mask));
467 }
get_end_mask() const468 int64_t StridedSliceV2::get_end_mask() const {
469   auto value_ptr = GetAttr(kEndMask);
470   return GetValue<int64_t>(value_ptr);
471 }
set_ellipsis_mask(int64_t ellipsis_mask)472 void StridedSliceV2::set_ellipsis_mask(int64_t ellipsis_mask) {
473   (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
474   std::bitset<sizeof(int64_t) * n_eight> bs(ellipsis_mask);
475   std::ostringstream buffer;
476   if (bs.count() > 1) {
477     buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask()
478            << ".";
479     MS_EXCEPTION(ValueError) << buffer.str();
480   }
481   (void)this->AddAttr(kEllipsisMask, api::MakeValue(ellipsis_mask));
482 }
get_ellipsis_mask() const483 int64_t StridedSliceV2::get_ellipsis_mask() const {
484   auto value_ptr = GetAttr(kEllipsisMask);
485   return GetValue<int64_t>(value_ptr);
486 }
set_new_axis_mask(int64_t new_axis_mask)487 void StridedSliceV2::set_new_axis_mask(int64_t new_axis_mask) {
488   (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
489   (void)this->AddAttr(kNewAxisMask, api::MakeValue(new_axis_mask));
490 }
get_new_axis_mask() const491 int64_t StridedSliceV2::get_new_axis_mask() const {
492   auto value_ptr = GetAttr(kNewAxisMask);
493   return GetValue<int64_t>(value_ptr);
494 }
set_shrink_axis_mask(int64_t shrink_axis_mask)495 void StridedSliceV2::set_shrink_axis_mask(int64_t shrink_axis_mask) {
496   (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
497   (void)this->AddAttr(kShrinkAxisMask, api::MakeValue(shrink_axis_mask));
498 }
get_shrink_axis_mask() const499 int64_t StridedSliceV2::get_shrink_axis_mask() const {
500   auto value_ptr = GetAttr(kShrinkAxisMask);
501   return GetValue<int64_t>(value_ptr);
502 }
Init(int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask)503 void StridedSliceV2::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
504                           int64_t shrink_axis_mask) {
505   this->set_begin_mask(begin_mask);
506   this->set_end_mask(end_mask);
507   this->set_ellipsis_mask(ellipsis_mask);
508   this->set_new_axis_mask(new_axis_mask);
509   this->set_shrink_axis_mask(shrink_axis_mask);
510 }
511 
StridedSliceV2Infer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)512 AbstractBasePtr StridedSliceV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
513                                     const std::vector<AbstractBasePtr> &input_args) {
514   MS_EXCEPTION_IF_NULL(primitive);
515   const int64_t input_num = 4;
516   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
517   return std::make_shared<abstract::AbstractTensor>(StridedSliceV2InferType(primitive, input_args),
518                                                     StridedSliceV2InferShape(primitive, input_args));
519 }
520 
521 // AG means auto generated
522 class MIND_API AGStridedSliceV2Infer : public abstract::OpInferBase {
523  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const524   BaseShapePtr InferShape(const PrimitivePtr &primitive,
525                           const std::vector<AbstractBasePtr> &input_args) const override {
526     return StridedSliceV2InferShape(primitive, input_args);
527   }
528 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const529   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
530     return StridedSliceV2InferType(primitive, input_args);
531   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const532   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
533                                     const std::vector<AbstractBasePtr> &input_args) const override {
534     return StridedSliceV2Infer(engine, primitive, input_args);
535   }
536 
GetValueDependArgIndices() const537   std::set<int64_t> GetValueDependArgIndices() const override { return {1, 2, 3}; }
538 };
539 
540 REGISTER_PRIMITIVE_OP_INFER_IMPL(StridedSliceV2, prim::kPrimStridedSliceV2, AGStridedSliceV2Infer, false);
541 }  // namespace ops
542 }  // namespace mindspore
543