/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_UTILS_TENSOR_INDEX_PY_H_
#define MINDSPORE_CCSRC_UTILS_TENSOR_INDEX_PY_H_

#include <algorithm>
#include <cstdint>
#include <limits>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>
#include "pybind11/numpy.h"
#include "pybind11/pytypes.h"
#include "ir/map_tensor.h"
#include "pybind_api/ir/tensor_py.h"
#include "include/common/utils/convert_utils_py.h"
#include "pipeline/pynative/base.h"

namespace py = pybind11;

namespace mindspore {
namespace tensor {
using tensor::TensorPy;
//
// Tensor index python adapter.
//
// Sentinel meaning "dimension size unknown / unbounded" for slice normalization.
constexpr int64_t kIndexMax = std::numeric_limits<int64_t>::max();

// Runtime category of a single Python index element (None, ..., int, bool, slice, Tensor, list, tuple, ndarray, float).
enum class TensorIndexType { None = 0, Ellipsis, Integer, Boolean, Slice, Tensor, List, Tuple, Array, Float };
// Kinds of data/value transformations applied while lowering a Python-style
// tensor index (getitem/setitem) into concrete backend operators.
enum class ValueTransferType {
  kUnknown,
  kTensorScatterUpdate,
  kExpandDims,
  kBroadCast,
  kCast,
  kSelect,
  kGather,
  kStrideSlice,
  kStrideSliceWithMask,
  kGatherND,
  kScatterNdUpdate,
  kReshape,
  kSelectView,
  kUnsqueeze,
  kCopyView,
  kScatterND,
  kNumberToTensor,
  kHandleSequenceValue,
  kByPass,
  kReSetItemByIndex,
  kCopySlice,
  kSetItemByBool,
  kEmptyTensor,
  kSetItemByEllipsis,
  kFormatIndexTensor,
  kGetitemByBoolTensor,
  kSetitemByBoolTensor,
  kJustReturn,
  kRaiseIndexError
};

// Whether the current index operation reads (__getitem__) or writes (__setitem__).
enum class IndexOpType { GetItem = 0, SetItem };

77 class Slice final {
78  public:
Slice(const py::object & start_index,const py::object & stop_index,const py::object & step_index)79   Slice(const py::object &start_index, const py::object &stop_index, const py::object &step_index) {
80     dim_size_ = kIndexMax;
81     if (py::isinstance<Tensor>(step_index) || IsStubTensor(step_index)) {
82       auto step_tensor = IsStubTensor(step_index) ? ConvertStubTensor(step_index) : step_index.cast<TensorPtr>();
83       MS_EXCEPTION_IF_NULL(step_tensor);
84       if (step_tensor->data_type() == kMetaTypeNone) {
85         step_ = 1;
86       } else {
87         step_ = GetTensorData(step_tensor);
88       }
89     } else if (py::isinstance<py::none>(step_index)) {
90       step_ = 1;
91     } else if (py::isinstance<py::int_>(step_index)) {
92       step_ = step_index.cast<int64_t>();
93       if (step_ == 0) {
94         MS_EXCEPTION(ValueError) << "For 'StridedSlice', 'strides' cannot contain 0";
95       }
96       if (step_ < -kIndexMax) {
97         step_ = -kIndexMax;
98       }
99     }
100     start_ = NormalizeIndex(start_index, step_, dim_size_);
101     stop_ = NormalizeIndex(stop_index, -step_, dim_size_);
102     stop_init_by_none_ = InitByNone(stop_index);
103     start_init_by_none_ = InitByNone(start_index);
104   }
105 
Slice(int64_t start_index,int64_t stop_index,int64_t step_index,int64_t dim_size,bool start_init_by_none,bool stop_init_by_none)106   Slice(int64_t start_index, int64_t stop_index, int64_t step_index, int64_t dim_size, bool start_init_by_none,
107         bool stop_init_by_none) {
108     dim_size_ = dim_size;
109     step_ = step_index;
110     if (step_ == 0) {
111       MS_EXCEPTION(ValueError) << "For 'StridedSlice', 'strides' cannot contain 0";
112     }
113     if (step_ < -kIndexMax) {
114       step_ = -kIndexMax;
115     }
116 
117     start_ = NormalizeIndex(start_index, dim_size_);
118     stop_ = NormalizeIndex(stop_index, dim_size_);
119     start_init_by_none_ = start_init_by_none;
120     stop_init_by_none_ = stop_init_by_none;
121   }
122 
123   // Empty slice (None:None:None) -> (0:DimSize:1)
Slice()124   Slice() : Slice(0, kIndexMax, 1, kIndexMax, true, true) {}
125 
Slice(const Slice & slice,int64_t dim_size)126   Slice(const Slice &slice, int64_t dim_size)
127       : Slice(std::min(slice.start_, dim_size), std::min(slice.stop_, dim_size), slice.step_,
128               std::min(slice.dim_size_, dim_size), slice.start_init_by_none_, slice.stop_init_by_none_) {}
129 
GetTensorData(const TensorPtr & tensor)130   static inline int64_t GetTensorData(const TensorPtr &tensor) {
131     MS_EXCEPTION_IF_NULL(tensor);
132     const auto &device_address = tensor->device_address();
133     if (device_address != nullptr) {
134       tensor->data_sync();
135     }
136     if (!tensor->shape().empty()) {
137       MS_EXCEPTION(TypeError) << "Only integer scalar tensors can be converted to a scalar index";
138     }
139     int64_t tensor_value = 0;
140     if (tensor->data_type() == kNumberTypeInt32) {
141       tensor_value = *static_cast<int32_t *>(tensor->data_c());
142     } else if (tensor->data_type() == kNumberTypeInt64) {
143       tensor_value = *static_cast<int64_t *>(tensor->data_c());
144     }
145     return tensor_value;
146   }
147 
start()148   inline int64_t start() const { return start_; }
149 
start_init_by_none()150   inline bool start_init_by_none() const { return start_init_by_none_; }
151 
stop_init_by_none()152   inline bool stop_init_by_none() const { return stop_init_by_none_; }
153 
stop()154   inline int64_t stop() const { return stop_; }
155 
step()156   inline int64_t step() const { return step_; }
dim_size()157   inline int64_t dim_size() const { return dim_size_; }
158 
159  private:
160   int64_t start_ = 0;
161   int64_t stop_ = 0;
162   int64_t step_ = 0;
163   int64_t dim_size_ = 0;
164   bool start_init_by_none_ = false;
165   bool stop_init_by_none_ = false;
166 
NormalizeIndex(int64_t index,int64_t dim_size)167   static inline int64_t NormalizeIndex(int64_t index, int64_t dim_size) {
168     int64_t new_index = index;
169     if (dim_size == kIndexMax) {
170       return new_index;
171     }
172     if (new_index < 0) {
173       MS_EXCEPTION_IF_ZERO("DimsSize should not be zero", dim_size);
174       return new_index < -dim_size ? 0 : (dim_size + (new_index % dim_size)) % dim_size;  // NOLINT
175     }
176     return new_index < dim_size ? new_index : dim_size;
177   }
178 
NormalizeIndex(const TensorPtr & index,int64_t step,int64_t dim_size)179   static inline int64_t NormalizeIndex(const TensorPtr &index, int64_t step, int64_t dim_size) {
180     MS_EXCEPTION_IF_NULL(index);
181     if (index->data_type() == kMetaTypeNone) {
182       return step > 0 ? 0 : dim_size;
183     }
184     int64_t new_index = GetTensorData(index);
185     if (dim_size == kIndexMax) {
186       return new_index;
187     }
188     if (new_index < 0) {
189       MS_EXCEPTION_IF_ZERO("DimsSize should not be zero", dim_size);
190       return new_index < -dim_size ? 0 : (dim_size + (new_index % dim_size)) % dim_size;  // NOLINT
191     }
192     return new_index < dim_size ? new_index : dim_size;
193   }
194 
InitByNone(const py::object & index)195   static inline bool InitByNone(const py::object &index) {
196     if (py::isinstance<Tensor>(index)) {
197       auto tensor_index = index.cast<TensorPtr>();
198       MS_EXCEPTION_IF_NULL(tensor_index);
199       return tensor_index->data_type() == kMetaTypeNone;
200     } else if (IsStubTensor(index)) {
201       auto type_id = GetStubTensorInfo(index).second->type_id();
202       return type_id == kMetaTypeNone;
203     } else if (py::isinstance<py::none>(index)) {
204       return true;
205     }
206     return false;
207   }
208 
NormalizeIndex(const py::object & index,int64_t step,int64_t dim_size)209   static inline int64_t NormalizeIndex(const py::object &index, int64_t step, int64_t dim_size) {
210     int64_t normalized_index;
211     if (py::isinstance<Tensor>(index) || IsStubTensor(index)) {
212       auto tensor_index = IsStubTensor(index) ? ConvertStubTensor(index) : index.cast<TensorPtr>();
213       MS_EXCEPTION_IF_NULL(tensor_index);
214       normalized_index = NormalizeIndex(tensor_index, step, dim_size);
215     } else if (py::isinstance<py::int_>(index)) {
216       normalized_index = NormalizeIndex(index.cast<int64_t>(), dim_size);
217     } else if (py::isinstance<py::none>(index)) {
218       normalized_index = step > 0 ? 0 : dim_size;
219     } else {
220       MS_LOG(EXCEPTION) << "Slice index type must be int, tensor or none.";
221     }
222     return normalized_index;
223   }
224   friend inline std::ostream &operator<<(std::ostream &out, const Slice &slice) {
225     return out << "start: " << slice.start_ << ","
226                << "stop: " << slice.stop_ << ","
227                << "step: " << slice.step_;
228   }
229 };
230 
231 class TensorIndex final {
232  public:
TensorIndex(const py::none &)233   explicit TensorIndex(const py::none &) : type_(TensorIndexType::None) {}
234 
TensorIndex(const py::ellipsis &)235   explicit TensorIndex(const py::ellipsis &) : type_(TensorIndexType::Ellipsis) {}
236 
TensorIndex(int64_t integer)237   explicit TensorIndex(int64_t integer) : integer_(integer), type_(TensorIndexType::Integer) {}
TensorIndex(int integer)238   explicit TensorIndex(int integer) : TensorIndex(static_cast<int64_t>(integer)) {}
TensorIndex(const py::int_ & integer)239   explicit TensorIndex(const py::int_ &integer) : TensorIndex(integer.cast<int64_t>()) {}
240 
TensorIndex(bool boolean)241   explicit TensorIndex(bool boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
TensorIndex(const py::bool_ & boolean)242   explicit TensorIndex(const py::bool_ &boolean) : TensorIndex(py::cast<bool>(boolean)) {}
243 
TensorIndex(const Slice & slice)244   explicit TensorIndex(const Slice &slice) : slice_(slice), type_(TensorIndexType::Slice) {}
TensorIndex(const py::slice & py_slice)245   explicit TensorIndex(const py::slice &py_slice)
246       : TensorIndex(Slice(py_slice.attr("start"), py_slice.attr("stop"), py_slice.attr("step"))) {}
247 
TensorIndex(TensorPtr tensor)248   explicit TensorIndex(TensorPtr tensor) : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
TensorIndex(py::array py_array)249   explicit TensorIndex(py::array py_array) : array_(std::move(py_array)), type_(TensorIndexType::Array) {}
250 
TensorIndex(py::list py_list)251   explicit TensorIndex(py::list py_list) : list_(std::move(py_list)), type_(TensorIndexType::List) {}
TensorIndex(py::tuple py_tuple)252   explicit TensorIndex(py::tuple py_tuple) : tuple_(std::move(py_tuple)), type_(TensorIndexType::Tuple) {}
253 
TensorIndex(float float_input)254   explicit TensorIndex(float float_input) : float_(float_input), type_(TensorIndexType::Float) {}
TensorIndex(const py::float_ & float_input)255   explicit TensorIndex(const py::float_ &float_input) : TensorIndex(float_input.cast<float>()) {}
256 
TensorIndex(const py::handle & py_object)257   explicit TensorIndex(const py::handle &py_object) {
258     if (py::isinstance<py::list>(py_object)) {
259       this->list_ = py_object.cast<py::list>();
260       this->type_ = TensorIndexType::List;
261     } else if (py::isinstance<py::int_>(py_object) && !py::isinstance<py::bool_>(py_object)) {
262       this->integer_ = py_object.cast<py::int_>();
263       this->type_ = TensorIndexType::Integer;
264     } else if (py::isinstance<py::float_>(py_object)) {
265       this->float_ = py_object.cast<py::float_>();
266       this->type_ = TensorIndexType::Float;
267     } else if (py::isinstance<tensor::Tensor>(py_object)) {
268       this->tensor_ = py_object.cast<tensor::TensorPtr>();
269       this->type_ = TensorIndexType::Tensor;
270     } else if (py::isinstance<py::tuple>(py_object)) {
271       this->tuple_ = py_object.cast<py::tuple>();
272       this->type_ = TensorIndexType::Tuple;
273     } else if (py::isinstance<py::slice>(py_object)) {
274       this->slice_ = TensorIndex(py_object.cast<py::slice>()).slice_;
275       this->type_ = TensorIndexType::Slice;
276     } else if (py::isinstance<py::ellipsis>(py_object)) {
277       this->type_ = TensorIndexType::Ellipsis;
278     } else if (py::isinstance<py::none>(py_object)) {
279       this->type_ = TensorIndexType::None;
280     } else if (py::isinstance<py::array>(py_object)) {
281       this->array_ = py_object.cast<py::array>();
282       this->type_ = TensorIndexType::Array;
283     } else if (py::isinstance<py::bool_>(py_object)) {
284       this->boolean_ = py_object.cast<py::bool_>();
285       this->type_ = TensorIndexType::Boolean;
286     } else if (IsStubTensor(py_object)) {
287       this->tensor_ = ConvertStubTensor(py_object);
288       this->type_ = TensorIndexType::Tensor;
289     }
290   }
291 
IsNone()292   inline bool IsNone() const { return type_ == TensorIndexType::None; }
293 
IsEllipsis()294   inline bool IsEllipsis() const { return type_ == TensorIndexType::Ellipsis; }
295 
IsInteger()296   inline bool IsInteger() const { return type_ == TensorIndexType::Integer; }
297 
integer()298   inline int64_t integer() const { return integer_; }
299 
IsBoolean()300   inline bool IsBoolean() const { return type_ == TensorIndexType::Boolean; }
301 
boolean()302   inline bool boolean() const { return boolean_; }
303 
IsSlice()304   inline bool IsSlice() const { return type_ == TensorIndexType::Slice; }
305 
slice()306   inline const Slice &slice() const { return slice_; }
307 
IsTensor()308   inline bool IsTensor() const { return type_ == TensorIndexType::Tensor; }
309 
tensor()310   inline const TensorPtr &tensor() const { return tensor_; }
311 
IsList()312   inline bool IsList() const { return type_ == TensorIndexType::List; }
313 
list()314   inline const py::list &list() const { return list_; }
315 
IsTuple()316   inline bool IsTuple() const { return type_ == TensorIndexType::Tuple; }
317 
tuple()318   inline const py::tuple &tuple() const { return tuple_; }
319 
IsSequence()320   inline bool IsSequence() const { return IsList() || IsTuple(); }
321 
array()322   inline const py::array &array() const { return array_; }
323 
IsArray()324   inline bool IsArray() const { return type_ == TensorIndexType::Array; }
325 
floating_point()326   inline const float &floating_point() const { return float_; }
327 
IsFloat()328   inline bool IsFloat() const { return type_ == TensorIndexType::Float; }
329 
type()330   inline const TensorIndexType &type() const { return type_; }
331 
332   static py::object GetItemByTensor(const ShapeVector &data_shape, const TensorPtr &index);
333   static py::object GetItemByList(const ShapeVector &data_shape, const TensorIndex &tensor_index);
334   static py::object GetItemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tensor_indexes);
335   static bool GetItemByTupleWithView(const ValuePtr &data_value, const ShapeVector &data_shape,
336                                      const py::object &py_index, std::vector<int64_t> *data_transfer_types,
337                                      std::vector<py::object> *data_transfer_args, const TypePtr &data_type);
338   static py::object GetItemByBool(const ValuePtr &data_value, const ShapeVector &data_shape, bool index);
339   static py::object GetItemByNumber(const ShapeVector &data_shape, int64_t index);
340   static py::object GetItemByNumberWithView(const ValuePtr &data_value, const ShapeVector &data_shape, int64_t index);
341   static py::object GetItemBySlice(const ValuePtr &data_value, const ShapeVector &data_shape,
342                                    const TensorIndex &py_index);
343   static py::object GetItemIndexSimpleIndex(const py::object &py_index, const ValuePtr &data_value,
344                                             const ShapeVector &data_shape);
345   static py::object GetItemIndexInfo(const py::object &data, const py::object &index, const py::bool_ &is_ascend);
346 
347   static py::object SetItemByNumber(const ShapeVector &data_shape, const TypePtr &data_type, bool is_parameter,
348                                     const TensorIndex &tensor_index, const TensorIndexType &py_value_type);
349   static py::object SetItemByNumberWithView(const ShapeVector &data_shape, const TypePtr &data_type, bool is_parameter,
350                                             const TensorIndex &tensor_index, const TensorIndexType &py_value_type,
351                                             const ValuePtr &data_value);
352   static py::object SetItemByTensor(const ShapeVector &data_shape, bool is_parameter, const TensorIndex &tensor_index,
353                                     const TensorIndexType &py_value_type);
354 
355   static py::object SetItemByTuple(const ShapeVector &data_shape, const TypePtr &data_type, const TensorIndex &py_index,
356                                    const TensorIndexType &py_value_type);
357 
358   static py::object SetItemBySlice(const ShapeVector &data_shape, const TypePtr &data_type, const TensorIndex &py_index,
359                                    const TensorIndexType &py_value_type, const ValuePtr &data_value);
360 
361   static py::object SetItemIndexInfo(const py::object &data, const py::object &index, const py::object &value,
362                                      const py::bool_ &is_ascend);
363 
364   static py::object SetItemIndexByIndexType(const TensorIndex &index, const py::object &py_index,
365                                             const ShapeVector &data_shape, const TypePtr &data_type,
366                                             const TensorIndexType &value_type, bool is_parameter);
367   static py::handle py_index_handle_;
368   static py::handle py_value_handle_;
369   static bool is_ascend_;
370   static py::module np_module_;
371   static IndexOpType index_op_type_;
372 
373  private:
374   int64_t integer_ = 0;
375   bool boolean_ = false;
376   float float_ = 0.0;
377   Slice slice_;
378   TensorPtr tensor_;
379   py::array array_;
380   py::list list_;
381   py::tuple tuple_;
382   TensorIndexType type_;
383 
384   // ***********************************************utils*******************************************
385   static void CheckGetItemIndex(const TensorIndexType &index_data_type);
386   static void CheckSetItemIndex(const TensorIndexType &index_data_type, const TensorIndexType &value_data_type);
387   template <typename T>
CheckTypeIsInstance(const T & type,const std::vector<T> & target_types)388   static inline bool CheckTypeIsInstance(const T &type, const std::vector<T> &target_types) {
389     return std::any_of(target_types.begin(), target_types.end(),
390                        [&type](const auto &target_type) { return target_type == type; });
391   }
JudgeDataDim(int64_t data_dim,int64_t min_data_dim,int64_t max_data_dim)392   static inline void JudgeDataDim(int64_t data_dim, int64_t min_data_dim, int64_t max_data_dim) {
393     if (data_dim < min_data_dim || data_dim > max_data_dim) {
394       MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [" << min_data_dim << ", " << max_data_dim
395                                << "], but got '" << data_dim << "'.";
396     }
397   }
398   template <typename T>
VectorToPyTuple(const std::vector<T> & item_shape)399   static inline py::tuple VectorToPyTuple(const std::vector<T> &item_shape) {
400     size_t tuple_size = item_shape.size();
401     py::tuple out(tuple_size);
402     for (size_t i = 0; i < tuple_size; i++) {
403       out[i] = item_shape[i];
404     }
405     return out;
406   }
407   static ShapeVector BroadCastShape(const ShapeVector &x_shape, const ShapeVector &y_shape);
BroadCastShape(const std::vector<ShapeVector> & tensor_indexes_shapes)408   static ShapeVector BroadCastShape(const std::vector<ShapeVector> &tensor_indexes_shapes) {
409     if (tensor_indexes_shapes.empty()) {
410       return {};
411     }
412     return std::accumulate(tensor_indexes_shapes.begin(), tensor_indexes_shapes.end(), tensor_indexes_shapes[0],
413                            [](const auto &output_shape, const auto &tensor_indexes_shape) {
414                              return BroadCastShape(output_shape, tensor_indexes_shape);
415                            });
416   }
SliceToVector(int64_t start,int64_t stop,int64_t step)417   static std::vector<int64_t> SliceToVector(int64_t start, int64_t stop, int64_t step) {
418     std::vector<int64_t> slice_ele_list_index;
419     if (step > 0) {
420       for (int64_t j = start; j < stop; j += step) {
421         (void)slice_ele_list_index.emplace_back(j);
422       }
423       return slice_ele_list_index;
424     }
425     for (int64_t j = start; j > stop; j += step) {
426       (void)slice_ele_list_index.emplace_back(j);
427     }
428     return slice_ele_list_index;
429   }
430 
431   // This is the c++ version of sequence_to_index in
432   // "mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py"
433   // Transforms sequence to tensor.
SequenceToTensor(const TensorIndex & tensor_index,int64_t dim_size)434   static inline TensorIndex SequenceToTensor(const TensorIndex &tensor_index, int64_t dim_size) {
435     return tensor_index.type_ == TensorIndexType::List ? SequenceToTensor<py::list>(tensor_index.list_, dim_size)
436                                                        : SequenceToTensor<py::tuple>(tensor_index.tuple_, dim_size);
437   }
438   template <typename T>
439   static TensorIndex SequenceToTensor(const T &sequence, int64_t dim_size);
440   static py::object Unpack(const py::object &x);
CheckRange(const py::object & x,int64_t dim_size)441   static inline py::object CheckRange(const py::object &x, int64_t dim_size) {
442     if (py::isinstance<py::int_>(x)) {
443       auto temp_x = x.cast<int64_t>();
444       if (temp_x >= dim_size || temp_x < -dim_size) {
445         MS_EXCEPTION(IndexError) << "index " << temp_x << " out of bounds for dimension with size " << dim_size;
446       }
447       MS_EXCEPTION_IF_ZERO("dim_size", dim_size);
448       return py::int_(CheckRange(temp_x, dim_size));
449     }
450     return x;
451   }
CheckRange(int64_t x,int64_t dim_size)452   static inline int64_t CheckRange(int64_t x, int64_t dim_size) {
453     MS_EXCEPTION_IF_ZERO("dim_size", dim_size);
454     return (dim_size + (x % dim_size)) % dim_size;
455   }
456 
CheckScalarValue(const py::handle & value)457   static bool CheckScalarValue(const py::handle &value) {
458     if (py::isinstance<Tensor>(value)) {
459       TensorPtr data = value.cast<TensorPtr>();
460       MS_EXCEPTION_IF_NULL(data);
461       auto data_shape = data->shape();
462       return data_shape.empty();
463     }
464     if (IsStubTensor(value)) {
465       auto data_shape = GetStubTensorInfo(value).first;
466       return data_shape.empty();
467     }
468     return CheckTypeIsInstance(TensorIndex(value).type(),
469                                {TensorIndexType::Float, TensorIndexType::Integer, TensorIndexType::Boolean});
470   }
471 
472   static py::object DeepList(const py::object &array_like, int64_t dim_size);
473   static py::object DeepTensorToNdArray(const py::object &array_like);
474   static py::array MakeNdArray(const py::object &a, int64_t dim_size);
475 
476   // This is the c++ version of _transform_ellipsis_to_slice in
477   // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
478   // Converts slice index into array
479   static std::vector<TensorIndex> TransformEllipsisToSlice(const ShapeVector &data_shape,
480                                                            const std::vector<TensorIndex> &indices);
481   static std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> GenerateIndexInfoFromTupleOfMixedTensors(
482     const std::vector<int64_t> &tensor_positions, const std::vector<ShapeVector> &tensor_indexes_shapes,
483     const ShapeVector &slice_shapes, const TensorIndex &py_fancy_position);
484   // This is the c++ version of slice2indices in
485   // "mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py"
486   // Converts slice index into array
487   static TensorIndex SliceToArray(const TensorIndex &tensor_index, const ShapeVector &shape);
488 
489   // This is the c++ version of convert_slice_to_tensor in
490   // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
491   // Converts slice index into array
492   static TensorIndex SliceToArray(const TensorPtr &index, const ShapeVector &final_shape, size_t slice_cnt,
493                                   const ShapeVector &broadcast_shape, const ShapeVector &slice_shape,
494                                   int64_t fancy_position);
495 
ComputeSliceShape(const ShapeVector & slice_shape,size_t broadcast_shape_len,size_t slice_cnt,int64_t fancy_position)496   static ShapeVector ComputeSliceShape(const ShapeVector &slice_shape, size_t broadcast_shape_len, size_t slice_cnt,
497                                        int64_t fancy_position) {
498     ShapeVector shape(slice_shape.size(), 1);
499     if (slice_cnt >= shape.size()) {
500       MS_EXCEPTION(IndexError) << "Index out of shape size.";
501     }
502     shape[slice_cnt] = slice_shape[slice_cnt];
503     ShapeVector temp_shape(broadcast_shape_len, 1);
504     (void)shape.insert(shape.begin() + fancy_position, temp_shape.begin(), temp_shape.end());
505     return shape;
506   }
507 
ComputeMultiples(const ShapeVector & origin_shape,const ShapeVector & broadcast_shape)508   static ShapeVector ComputeMultiples(const ShapeVector &origin_shape, const ShapeVector &broadcast_shape) {
509     int64_t len_gap = SizeToLong(broadcast_shape.size()) - SizeToLong(origin_shape.size());
510     ShapeVector output_shape = broadcast_shape;
511     (void)std::transform(broadcast_shape.begin() + len_gap, broadcast_shape.end(), origin_shape.begin(),
512                          output_shape.begin() + len_gap, [](int64_t x, int64_t y) {
513                            MS_EXCEPTION_IF_ZERO("dim of data shape", y);
514                            return x / y;
515                          });
516     return output_shape;
517   }
518 
GeneratePaddingShape(const ShapeVector & shape,int64_t length)519   static ShapeVector GeneratePaddingShape(const ShapeVector &shape, int64_t length) {
520     if (SizeToLong(shape.size()) > length) {
521       MS_EXCEPTION(ValueError) << "Can not pad " << shape << " to length " << length;
522     }
523     ShapeVector pad_shape(length - SizeToLong(shape.size()), 1);
524     (void)pad_shape.insert(pad_shape.begin(), shape.begin(), shape.end());
525     return pad_shape;
526   }
527   static py::object BroadCastTo(const ShapeVector &broadcast_shape, const py::object &item);
528 
529   // This is the c++ version of _transform_indexing_tensor in
530   // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
531   //  BroadCast tensor to the required
532   static TensorIndex BroadCastTensor(const ShapeVector &broadcast_shape, const ShapeVector &final_shape,
533                                      const ShapeVector &new_shape, const TensorPtr &item);
534   static constexpr int64_t set_item_by_one_tensor = 0;
535   static constexpr int64_t set_item_by_tuple_tensor = 1;
536   static constexpr int64_t set_item_by_non_tensor = 2;
537   static constexpr int64_t int32_bytes_number = 4;
538   static std::tuple<int64_t, py::object, ShapeVector> GetValueTransferType(const TensorIndexType &py_value_type,
539                                                                            int64_t op_type, const TypePtr &data_type,
540                                                                            bool is_view);
541 
542   // This is the c++ version of format_tuple_indices in
543   // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
544   // Format tuple indices by unpacking high-dimension tuple and removing expand
545   // dimension signs(Bool and None).
UnpackTuple(const TensorIndex & tensor_index)546   static inline TensorIndex UnpackTuple(const TensorIndex &tensor_index) {
547     return tensor_index.type_ == TensorIndexType::List ? UnpackTuple<py::list>(tensor_index.list_)
548                                                        : UnpackTuple<py::tuple>(tensor_index.tuple_);
549   }
550 
551   // Expand tuple TensorIndex to std::vector<TensorIndex>
ExpandToVector()552   inline std::vector<TensorIndex> ExpandToVector() const {
553     std::vector<TensorIndex> output;
554     if (type_ == TensorIndexType::Tuple) {
555       output.reserve(tuple_.size());
556       for (auto const &e : tuple_) {
557         (void)output.emplace_back(TensorIndex(e));
558       }
559     } else {
560       output.reserve(list_.size());
561       for (auto const &e : list_) {
562         (void)output.emplace_back(TensorIndex(e));
563       }
564     }
565     return output;
566   }
567 
568   template <typename T>
569   static TensorIndex UnpackTuple(const T &sequence);
570 
UseCopySlice(const std::vector<TensorIndex> & indices,int64_t data_dims)571   static inline bool UseCopySlice(const std::vector<TensorIndex> &indices, int64_t data_dims) {
572     constexpr size_t min_tuple_index_len = 2;
573     if (indices.size() >= min_tuple_index_len && LongToSize(data_dims) >= min_tuple_index_len) {
574       bool valid = indices[0].IsInteger() && indices[1].IsSlice() && indices[1].slice().step() == 1;
575       return valid && std::all_of(indices.begin() + min_tuple_index_len, indices.end(), [](const TensorIndex &x) {
576                return x.IsSlice() && x.slice().start_init_by_none() && x.slice().stop_init_by_none() &&
577                       x.slice().step() == 1;
578              });
579     }
580     return false;
581   }
582 
  // ***********************************************for get_item*******************************************
  // Builds a py::tuple of per-dimension non-zero position tensors for a tensor index
  // (NOTE(review): exact semantics, incl. the role of check_align, live in the .cc -- verify there).
  static py::tuple GenerateNonZeroIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index, bool check_align);
  // Same computation as GenerateNonZeroIndex, but returns the per-dimension index tensors as a vector.
  static std::vector<TensorPtr> GenerateNonZeroIndexTensorList(const ShapeVector &data_shape,
                                                               const TensorPtr &tensor_index, bool check_align);
  // This is the c++ version of get_stride_info_from_tuple in
  // "mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py"
  // Get stride info from a tuple index.
  static std::tuple<std::vector<std::vector<int64_t>>, std::vector<int64_t>> GetStrideInfoFromTuple(
    const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index);
  // Parses one tensor element of a tuple index, appending to the rewritten tuple index, the
  // collected tensor indexes and their positions. The bool return reports success
  // (NOTE(review): the exact failure condition is defined in the .cc implementation).
  static bool TensorGetitemByTupleParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index,
                                                   std::vector<TensorPtr> *tuple_index_new,
                                                   std::vector<TensorPtr> *tensor_indexes,
                                                   std::vector<int64_t> *tensor_positions, bool check_align);
  // Computes expand-dims information for a tuple index; returns a (flag, shape, rewritten index)
  // triple (NOTE(review): see the .cc for the precise meaning of each element).
  static std::tuple<bool, ShapeVector, std::vector<TensorIndex>> GetExpandDimsInfo(
    const ShapeVector &data_shape, const std::vector<TensorIndex> &index);
  // Assembles the final indices object from the broadcast/slice shape information produced while
  // parsing a tuple index.
  static py::object GenerateIndices(const std::vector<TensorPtr> &tuple_index_new,
                                    const std::vector<int64_t> &broadcast_shape,
                                    const std::vector<int64_t> &index_tensor_new_shape,
                                    const std::vector<int64_t> &final_shape,
                                    const std::vector<int64_t> &tensor_positions,
                                    const std::vector<int64_t> &slice_shapes, int64_t fancy_position);
  // Handles getitem with a tuple index; the transfer ops and their args to apply to the data are
  // reported through the two output pointers.
  static py::object TensorGetitemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index,
                                         std::vector<int64_t> *data_transfer_type,
                                         std::vector<py::object> *data_transfer_args);
607 
  // ***********************************************for set_item*******************************************
  // This is the c++ version of format_list_indices in
  // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
  // Convert list indices to array or list indices based on its contents.
  static TensorIndex FormatList(const TensorIndex &tensor_index, int64_t length);
  // Builds a tensor from the integer index i for data of the given shape
  // (NOTE(review): resulting layout/dtype are determined in the .cc implementation).
  static TensorPtr IntToTensor(int64_t i, const ShapeVector &shape);
  // Generates indices when every element of the tuple index is a tensor; the resulting index
  // shape and any extra data-transfer argument are reported through the output pointers.
  static py::object GenerateIndicesFromTupleOfTensor(const ShapeVector &data_shape,
                                                     const std::vector<TensorIndex> &tuple_index,
                                                     ShapeVector *output_index_shape, py::object *data_transfer_arg);
  // Updates `idx_advanced` and the `not_expanded_dim` flags in place for dimensions that remain
  // after index expansion.
  static void RemNotExpandedDims(int64_t *idx_advanced, bool expand_true, int64_t tensor_index_ndim, int64_t rem_ndim,
                                 std::vector<bool> *not_expanded_dim);
619 
FilterExpandedDims(const ShapeVector & shape,const std::vector<bool> & not_expanded_dim)620   static inline ShapeVector FilterExpandedDims(const ShapeVector &shape, const std::vector<bool> &not_expanded_dim) {
621     int64_t diff = SizeToLong(not_expanded_dim.size()) - SizeToLong(shape.size());
622     if (diff < 0) {
623       MS_EXCEPTION(ValueError) << "Input array must have the same size across all dimensions.";
624     }
625     std::vector<int64_t> res;
626     size_t index = std::min(shape.size(), not_expanded_dim.size() - static_cast<size_t>(diff));
627     for (size_t i = 0; i < index; i++) {
628       if (not_expanded_dim[(i + static_cast<size_t>(diff))]) {
629         (void)res.emplace_back(shape[i]);
630       }
631     }
632     return res;
633   }
634 
  // This is the c++ version of format_index in
  // "mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py"
  // Converts advanced index into array
  static TensorIndex FormatIndex(const TensorIndex &idx, const ShapeVector &data_shape, size_t cur_dim,
                                 bool *need_format);
  // Parses one tensor index while removing expanded dims, appending to `indices_out`/`shapes` and
  // tracking sequence-typed indices plus the current dimension. The bool return reports success
  // (NOTE(review): the exact failure condition is defined in the .cc implementation).
  static bool RemoveExpandedDimsParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &index_out,
                                                 std::vector<TensorIndex> *indices_out,
                                                 std::vector<ShapeVector> *shapes, bool *has_sequence, size_t *cur_dim,
                                                 bool check_align);
  // Removes expanded dims from the index list; returns the rewritten indices together with the
  // reduced value shape, recording value-transfer ops/args and format info via the out pointers.
  static std::pair<std::vector<TensorIndex>, ShapeVector> RemoveExpandedDims(
    const std::vector<TensorIndex> &indices, const ShapeVector &data_shape, const ShapeVector &value_shape,
    std::vector<int64_t> *value_transfer_type, std::vector<py::object> *value_transfer_args, int64_t *idx_advanced,
    bool *by_pass, std::vector<size_t> *format_index, std::vector<int64_t> *format_dim);
  // Generates indices from a mixed tuple index for setitem; the resulting index shape and any
  // extra data-transfer argument are reported through the output pointers.
  static py::object GenerateIndicesFromTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index,
                                             int64_t py_fancy_position, bool *by_pass, ShapeVector *output_index_shape,
                                             py::object *data_transfer_arg);
  // Re-runs setitem with a rebuilt tuple index and the accumulated value-transfer ops/args.
  static py::object ReSetitemByTensor(const std::vector<TensorIndex> &new_tuple_index,
                                      const std::vector<int64_t> &value_transfer_types,
                                      const std::vector<py::object> &value_transfer_args);
  // Setitem with a tuple index and a tensor value; value-transfer ops/args are accumulated into
  // the output vectors.
  static py::object SetitemByTupleWithTensor(const ShapeVector &data_shape, const std::vector<TensorIndex> &indices,
                                             const ShapeVector &value_shape, std::vector<int64_t> *value_transfer_type,
                                             std::vector<py::object> *value_transfer_args);

  // Setitem with a slice index and a tensor value.
  static py::object SetitemBySliceWithTensor(const ShapeVector &data_shape, const TensorIndex &slice_index,
                                             std::vector<int64_t> *value_transfer_type,
                                             std::vector<py::object> *value_transfer_args, const ValuePtr &data_value,
                                             const TypePtr &data_type);

  // Setitem with a boolean tensor index; the tensor-update op to apply is reported through
  // `tensor_update_type`.
  static py::array SetItemByTensorByBool(const ShapeVector &data_shape, const TensorPtr &index, int64_t data_dims,
                                         std::vector<int64_t> *value_transfer_types,
                                         std::vector<py::object> *value_transfer_args,
                                         ValueTransferType *tensor_update_type);

  // Stream printers for a single index and for an index list (debug/logging).
  friend std::ostream &operator<<(std::ostream &stream, const TensorIndex &tensor_index);
  friend std::ostream &operator<<(std::ostream &stream, const std::vector<TensorIndex> &tensor_indices);
670 };
671 }  // namespace tensor
672 }  // namespace mindspore
673 #endif  // MINDSPORE_CCSRC_UTILS_TENSOR_INDEX_PY_H_
674