• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/strided_slice.h"
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include <bitset>
24 #include "ops/op_utils.h"
25 #include "utils/check_convert_utils.h"
26 #include "abstract/primitive_infer_map.h"
27 
28 namespace mindspore {
29 namespace ops {
30 namespace {
// Expands `num` into its base-2 digits, least-significant digit first.
// Zero produces the single digit {0}. Digits come from C++ `%` and `/`, so a negative input yields digits in
// {-1, 0} (e.g. -3 -> {-1, -1}) rather than a two's-complement pattern.
std::vector<int64_t> TenToTwo(int64_t num) {
  if (num == 0) {
    return {0};
  }
  std::vector<int64_t> digits;
  for (int64_t rest = num; rest != 0; rest /= 2) {
    digits.push_back(rest % 2);
  }
  return digits;
}
44 
get_stride_with_not_zero(int64_t start_pos,int64_t end_pos,int64_t strides)45 int64_t get_stride_with_not_zero(int64_t start_pos, int64_t end_pos, int64_t strides) {
46   int64_t slicing_length = 0;
47   if (strides != 0) {
48     slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
49   } else {
50     MS_EXCEPTION(ValueError) << "the strides must be non-zero but got " << strides;
51   }
52   return slicing_length;
53 }
54 
EllipsisInferShape(const PrimitivePtr & primitive,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & strides_v,std::vector<int64_t> * infer_shape,size_t i,size_t j,bool has_ellipsis)55 void EllipsisInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
56                         const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
57                         const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
58                         bool has_ellipsis) {
59   if (!has_ellipsis) {
60     return;
61   }
62   MS_EXCEPTION_IF_NULL(primitive);
63   auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
64   MS_EXCEPTION_IF_NULL(strided_slice_prim);
65   size_t x_rank = x_shape.size();
66   size_t slice_len = begin_v.size();
67   std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
68   std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
69   std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
70   std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
71   (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(new_axis_pos.size()), kGreaterEqual,
72                                            SizeToLong(slice_len), primitive->name());
73 
74   size_t num = 0;
75   for (size_t n = j + 1; n < slice_len; n++) {
76     if (new_axis_pos[n] == 1) {
77       num++;
78     }
79   }
80 
81   size_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
82   (void)infer_shape->insert(infer_shape->end(), x_shape.begin() + SizeToInt(i),
83                             x_shape.begin() + SizeToLong(i + ellipsis_occupied_dims));
84   j += 1;
85   i += ellipsis_occupied_dims;
86 
87   while (i < x_rank || j < slice_len) {
88     int64_t x_dim_size = x_shape[i];
89     int64_t start = begin_v[j];
90     int64_t finish = end_v[j];
91     int64_t strides = strides_v[j];
92     if (j < begin_pos.size() || j < slice_len) {
93       start = strides_v[j] < 0 ? -1 : 0;
94     }
95     if (j < end_pos.size() && end_pos[j] == 1) {
96       finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
97     }
98     if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
99       infer_shape->push_back(1);
100       j += 1;
101       continue;
102     }
103     if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
104       if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
105         MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
106       }
107       j += 1;
108       i += 1;
109       continue;
110     }
111     int64_t slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
112     infer_shape->push_back(slicing_length);
113     i += 1;
114     j += 1;
115   }
116   return;
117 }
118 
CheckAndGetValidStrides(const AbstractBasePtr & stride_arg)119 const std::vector<int64_t> CheckAndGetValidStrides(const AbstractBasePtr &stride_arg) {
120   MS_EXCEPTION_IF_NULL(stride_arg);
121   auto temp_strides = stride_arg->cast<abstract::AbstractTuplePtr>()->BuildValue();
122   MS_EXCEPTION_IF_NULL(temp_strides);
123   auto strides = GetValue<std::vector<int64_t>>(temp_strides);
124   if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride == 0; })) {
125     MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
126   }
127   return strides;
128 }
129 
ComputeInferShape(const PrimitivePtr & primitive,const std::vector<int64_t> & begin_v,const std::vector<int64_t> & end_v,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & strides_v)130 std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
131                                        const std::vector<int64_t> &end_v, const std::vector<int64_t> &x_shape,
132                                        const std::vector<int64_t> &strides_v) {
133   std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
134   std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
135   std::vector<int64_t> ellipsis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEllipsisMask)));
136   std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
137   std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
138   size_t i = 0;
139   size_t j = 0;
140   int64_t start;
141   int64_t finish;
142   int64_t strides;
143   int64_t slicing_length;
144   bool has_ellipsis = false;
145   std::vector<int64_t> infer_shape;
146   size_t slice_len = begin_v.size();
147   size_t x_rank = x_shape.size();
148   (void)CheckAndConvertUtils::CheckInteger("end_v size", SizeToLong(end_v.size()), kGreaterEqual, SizeToLong(slice_len),
149                                            primitive->name());
150   (void)CheckAndConvertUtils::CheckInteger("strides_v size", SizeToLong(strides_v.size()), kGreaterEqual,
151                                            SizeToLong(slice_len), primitive->name());
152   while (i < x_rank || j < slice_len) {
153     int64_t x_dim_size = x_shape[i];
154     if (j < slice_len) {
155       start = begin_v[j];
156       finish = end_v[j];
157       strides = strides_v[j];
158       if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
159         has_ellipsis = true;
160         break;
161       }
162       if (j < begin_pos.size() && begin_pos[j] == 1) {
163         start = strides_v[j] < 0 ? -1 : 0;
164       }
165       if (j < end_pos.size() && end_pos[j] == 1) {
166         finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
167       }
168       if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
169         infer_shape.push_back(1);
170         j += 1;
171         continue;
172       }
173       if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
174         if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
175           MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
176         }
177         j += 1;
178         i += 1;
179         continue;
180       }
181     } else {
182       start = 0;
183       finish = x_shape[0];
184       strides = 1;
185     }
186     auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
187     MS_EXCEPTION_IF_NULL(strided_slice_prim);
188     slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
189     infer_shape.push_back(slicing_length);
190     i += 1;
191     j += 1;
192   }
193   EllipsisInferShape(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
194   return infer_shape;
195 }
196 
StridedSliceInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)197 abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
198                                           const std::vector<AbstractBasePtr> &input_args) {
199   MS_EXCEPTION_IF_NULL(primitive);
200   auto tuple_begin_v = input_args[kInputIndex1]->cast<abstract::AbstractTuplePtr>();
201   MS_EXCEPTION_IF_NULL(tuple_begin_v);
202   auto temp_begin_v = tuple_begin_v->BuildValue();
203   MS_EXCEPTION_IF_NULL(temp_begin_v);
204   auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
205 
206   auto tuple_end_v = input_args[kInputIndex2]->cast<abstract::AbstractTuplePtr>();
207   MS_EXCEPTION_IF_NULL(tuple_end_v);
208   auto temp_end_v = tuple_end_v->BuildValue();
209   MS_EXCEPTION_IF_NULL(temp_end_v);
210   auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
211   auto strides_v = CheckAndGetValidStrides(input_args[kInputIndex3]);
212 
213   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
214   auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
215   auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
216   auto ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, x_shape, strides_v);
217   if (min_shape.empty() || max_shape.empty()) {
218     return std::make_shared<abstract::Shape>(ret_in_shape);
219   }
220   auto ret_min_shape = ComputeInferShape(primitive, begin_v, end_v, min_shape, strides_v);
221   auto ret_max_shape = ComputeInferShape(primitive, begin_v, end_v, max_shape, strides_v);
222   return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
223 }
224 
StridedSliceInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)225 TypePtr StridedSliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
226   const int64_t x_index = 0;
227   return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
228 }
229 }  // namespace
230 
set_begin_mask(const int64_t begin_mask)231 void StridedSlice::set_begin_mask(const int64_t begin_mask) {
232   (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
233   (void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
234 }
get_begin_mask() const235 int64_t StridedSlice::get_begin_mask() const {
236   auto value_ptr = GetAttr(kBeginMask);
237   return GetValue<int64_t>(value_ptr);
238 }
set_end_mask(const int64_t end_mask)239 void StridedSlice::set_end_mask(const int64_t end_mask) {
240   (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
241   (void)this->AddAttr(kEndMask, MakeValue(end_mask));
242 }
get_end_mask() const243 int64_t StridedSlice::get_end_mask() const {
244   auto value_ptr = GetAttr(kEndMask);
245   return GetValue<int64_t>(value_ptr);
246 }
set_ellipsis_mask(const int64_t ellipsis_mask)247 void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
248   (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
249   std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
250   std::ostringstream buffer;
251   if (bs.count() > 1) {
252     buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
253     MS_EXCEPTION(ValueError) << buffer.str();
254   }
255   (void)this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
256 }
get_ellipsis_mask() const257 int64_t StridedSlice::get_ellipsis_mask() const {
258   auto value_ptr = GetAttr(kEllipsisMask);
259   return GetValue<int64_t>(value_ptr);
260 }
set_new_axis_mask(const int64_t new_axis_mask)261 void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
262   (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
263   (void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
264 }
get_new_axis_mask() const265 int64_t StridedSlice::get_new_axis_mask() const {
266   auto value_ptr = GetAttr(kNewAxisMask);
267   return GetValue<int64_t>(value_ptr);
268 }
set_shrink_axis_mask(const int64_t shrink_axis_mask)269 void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
270   (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
271   (void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
272 }
get_shrink_axis_mask() const273 int64_t StridedSlice::get_shrink_axis_mask() const {
274   auto value_ptr = GetAttr(kShrinkAxisMask);
275   return GetValue<int64_t>(value_ptr);
276 }
Init(const int64_t begin_mask,const int64_t end_mask,const int64_t ellipsis_mask,const int64_t new_axis_mask,const int64_t shrink_axis_mask)277 void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
278                         const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
279   this->set_begin_mask(begin_mask);
280   this->set_end_mask(end_mask);
281   this->set_ellipsis_mask(ellipsis_mask);
282   this->set_new_axis_mask(new_axis_mask);
283   this->set_shrink_axis_mask(shrink_axis_mask);
284 }
285 
compute_slicing_length(int64_t start_pos,int64_t end_pos,int64_t strides,int64_t x_dim) const286 int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const {
287   int64_t slicing_length = 0;
288   if (strides > 0) {
289     if ((start_pos >= x_dim) || end_pos < -x_dim) {
290       slicing_length = 0;
291     } else {
292       if (-x_dim <= start_pos && start_pos < 0) {
293         start_pos += x_dim;
294       }
295       if (start_pos < -x_dim) {
296         start_pos = 0;
297       }
298       if (-x_dim <= end_pos && end_pos < 0) {
299         end_pos += x_dim;
300       }
301       if (end_pos > x_dim) {
302         end_pos = x_dim;
303       }
304       if (start_pos > end_pos) {
305         slicing_length = 0;
306       } else {
307         slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
308       }
309     }
310   } else {
311     if (start_pos < -x_dim || end_pos >= x_dim) {
312       slicing_length = 0;
313     } else {
314       if (start_pos > 0 && start_pos < x_dim) {
315         start_pos += -x_dim;
316       }
317       if (start_pos >= x_dim) {
318         start_pos = -1;
319       }
320       if (end_pos >= 0 && end_pos < x_dim) {
321         end_pos += -x_dim;
322       }
323       if (end_pos < -x_dim - 1) {
324         end_pos = -x_dim - 1;
325       }
326       if (start_pos <= end_pos) {
327         slicing_length = 0;
328       } else {
329         slicing_length = get_stride_with_not_zero(start_pos, end_pos, strides);
330       }
331     }
332   }
333   return slicing_length;
334 }
335 
StridedSliceInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)336 AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
337                                   const std::vector<AbstractBasePtr> &input_args) {
338   MS_EXCEPTION_IF_NULL(primitive);
339   const int64_t input_num = 4;
340   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
341   return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(primitive, input_args),
342                                                     StridedSliceInferShape(primitive, input_args));
343 }
344 REGISTER_PRIMITIVE_C(kNameStridedSlice, StridedSlice);
345 }  // namespace ops
346 }  // namespace mindspore
347