• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/tensor_index_py.h"
18 #include <pybind11/stl.h>
19 #include <memory>
20 #include <string>
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24 #include <functional>
25 #include "pybind11/pytypes.h"
26 #include "pipeline/jit/ps/parse/parse_base.h"
27 #include "utils/hash_set.h"
28 #include "utils/log_adapter.h"
29 #include "pipeline/pynative/pynative_execute.h"
30 #include "mindspore/core/ops/array_ops.h"
31 
32 namespace mindspore::tensor {
33 using tensor::TensorPy;
34 py::handle TensorIndex::py_index_handle_ = py::none();
35 py::handle TensorIndex::py_value_handle_ = py::none();
36 bool TensorIndex::is_ascend_ = false;
37 IndexOpType TensorIndex::index_op_type_ = IndexOpType::GetItem;
38 py::module TensorIndex::np_module_ = py::module();
39 static const std::vector<TypeId> kIntTypes{kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
40 // ***********************************************utils*******************************************
operator <<(std::ostream & stream,const TensorIndex & tensor_index)41 std::ostream &operator<<(std::ostream &stream, const TensorIndex &tensor_index) {
42   TensorIndexType tensor_index_type = tensor_index.type();
43   switch (tensor_index_type) {
44     case TensorIndexType::None: {
45       stream << "None";
46       break;
47     }
48     case TensorIndexType::Integer: {
49       stream << tensor_index.integer();
50       break;
51     }
52     case TensorIndexType::Ellipsis: {
53       stream << "...";
54       break;
55     }
56     case TensorIndexType::Boolean: {
57       stream << std::boolalpha << tensor_index.boolean();
58       break;
59     }
60     case TensorIndexType::Slice: {
61       stream << tensor_index.slice();
62       break;
63     }
64     case TensorIndexType::Tensor: {
65       MS_EXCEPTION_IF_NULL(tensor_index.tensor());
66       stream << tensor_index.tensor()->ToString();
67       break;
68     }
69     case TensorIndexType::List: {
70       stream << tensor_index.list();
71       break;
72     }
73     case TensorIndexType::Tuple: {
74       stream << tensor_index.tuple();
75       break;
76     }
77     case TensorIndexType::Array: {
78       stream << tensor_index.array();
79       break;
80     }
81     case TensorIndexType::Float: {
82       stream << tensor_index.floating_point();
83       break;
84     }
85   }
86   return stream;
87 }
88 
operator <<(std::ostream & stream,const std::vector<TensorIndex> & tensor_indices)89 std::ostream &operator<<(std::ostream &stream, const std::vector<TensorIndex> &tensor_indices) {
90   stream << "(";
91   for (size_t i = 0; i < tensor_indices.size(); i++) {
92     stream << tensor_indices[i];
93     if (i < tensor_indices.size() - 1) {
94       stream << ", ";
95     }
96   }
97   stream << ")";
98   return stream;
99 }
100 
CheckGetItemIndex(const TensorIndexType & index_data_type)101 void TensorIndex::CheckGetItemIndex(const TensorIndexType &index_data_type) {
102   bool valid = CheckTypeIsInstance<TensorIndexType>(
103     index_data_type,
104     {TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Boolean, TensorIndexType::Slice,
105      TensorIndexType::Integer, TensorIndexType::Tuple, TensorIndexType::Ellipsis, TensorIndexType::None});
106   if (!valid) {
107     MS_EXCEPTION(IndexError)
108       << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
109          "tuple as index, but got "
110       << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
111   }
112 }
113 
CheckSetItemIndex(const TensorIndexType & index_data_type,const TensorIndexType & value_data_type)114 void TensorIndex::CheckSetItemIndex(const TensorIndexType &index_data_type, const TensorIndexType &value_data_type) {
115   CheckGetItemIndex(index_data_type);
116   bool valid = CheckTypeIsInstance<TensorIndexType>(
117     value_data_type, {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
118                       TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple});
119   if (!valid) {
120     MS_EXCEPTION(TypeError) << "only support numbers, Tensor, tuple, list as value, but got "
121                             << TensorIndex::py_value_handle_ << " with type "
122                             << TensorIndex::py_value_handle_.get_type();
123   }
124 }
125 
// Broadcasts two shapes against each other, aligning them on their trailing dimensions.
// Equal shapes are returned unchanged. Otherwise each aligned pair contributes y's dim when x's dim is 1, and
// x's dim when y's dim is 1 or both dims are equal; any other pair raises ValueError. The leading dimensions
// of the longer shape are copied in front of the aligned part.
ShapeVector TensorIndex::BroadCastShape(const ShapeVector &x_shape, const ShapeVector &y_shape) {
  if (x_shape == y_shape) {
    return x_shape;
  }
  const size_t x_rank = x_shape.size();
  const size_t y_rank = y_shape.size();
  const size_t common_rank = std::min(x_rank, y_rank);
  // Offsets of the first aligned dimension inside each shape.
  const size_t x_offset = x_rank - common_rank;
  const size_t y_offset = y_rank - common_rank;

  ShapeVector aligned_dims;
  for (size_t i = 0; i < common_rank; ++i) {
    const size_t x_shape_index = x_offset + i;
    const size_t y_shape_index = y_offset + i;
    const auto x_dim = x_shape[x_shape_index];
    const auto y_dim = y_shape[y_shape_index];
    if (x_dim == 1) {
      (void)aligned_dims.emplace_back(y_dim);
      continue;
    }
    if (y_dim == 1 || x_dim == y_dim) {
      (void)aligned_dims.emplace_back(x_dim);
      continue;
    }
    string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
    MS_EXCEPTION(ValueError) << "For '" << index_op_type
                             << "', x.shape and y.shape need to broadcast. The value of x.shape["
                             << std::to_string(x_shape_index) << "] or y.shape[" << std::to_string(y_shape_index)
                             << "] must be 1 or -1 when they are not the same, but got x.shape = " << x_shape
                             << " and y.shape = " << y_shape;
  }

  // When x is the shorter (or equally long) shape the unaligned prefix comes from y, otherwise from x.
  const bool prefix_from_y = (common_rank == x_rank);
  const ShapeVector &longer_shape = prefix_from_y ? y_shape : x_shape;
  const size_t prefix_len = prefix_from_y ? y_offset : x_offset;
  ShapeVector result(longer_shape.begin(), longer_shape.begin() + static_cast<int64_t>(prefix_len));
  (void)result.insert(result.end(), aligned_dims.begin(), aligned_dims.end());
  return result;
}
165 
166 template <typename T>
SequenceToTensor(const T & sequence,int64_t dim_size)167 TensorIndex TensorIndex::SequenceToTensor(const T &sequence, int64_t dim_size) {
168   if (sequence.empty()) {
169     return TensorIndex(py::bool_(false));
170   }
171   if (std::all_of(sequence.begin(), sequence.end(), [](auto &x) { return py::isinstance<py::bool_>(x); })) {
172     int64_t seq_size = SizeToLong(sequence.size());
173     if (seq_size != dim_size) {
174       MS_EXCEPTION(IndexError) << "dimension is " << dim_size << " but corresponding boolean dimension is " << seq_size;
175     }
176     py::list new_range_dim_size;
177     for (size_t i = 0; i < sequence.size(); i++) {
178       if (py::cast<bool>(sequence[i]) == true) {
179         new_range_dim_size.append(py::int_(i));
180       }
181     }
182     if (new_range_dim_size.empty()) {
183       return TensorIndex(py::bool_(false));
184     }
185     return TensorIndex(TensorPy::MakeTensor(MakeNdArray(new_range_dim_size, dim_size)));
186   }
187   py::array output = MakeNdArray(sequence, dim_size);
188   if (output.dtype() == pybind11::dtype("object")) {
189     MS_LOG(EXCEPTION) << "Sequence as indices must have the same size across all dimensions and elements must be "
190                          "integer (or boolean) type";
191   }
192   return TensorIndex(TensorPy::MakeTensor(output));
193 }
194 
// Strips single-element tuple/list wrappers: while `x` is a tuple or list holding exactly one element, that
// element is unpacked in turn. Anything else is returned as-is.
py::object TensorIndex::Unpack(const py::object &x) {
  if (py::isinstance<py::tuple>(x)) {
    auto as_tuple = py::cast<py::tuple>(x);
    if (as_tuple.size() == 1) {
      return Unpack(as_tuple[0]);
    }
  } else if (py::isinstance<py::list>(x)) {
    auto as_list = py::cast<py::list>(x);
    if (as_list.size() == 1) {
      return Unpack(as_list[0]);
    }
  }
  return x;
}
210 
211 template <typename T>
UnpackTuple(const T & sequence)212 TensorIndex TensorIndex::UnpackTuple(const T &sequence) {
213   py::tuple res(sequence.size());
214   for (size_t i = 0; i < sequence.size(); i++) {
215     if (py::isinstance<py::list>(sequence[i]) || py::isinstance<py::tuple>(sequence[i])) {
216       res[i] = Unpack(sequence[i]);
217     } else {
218       res[i] = sequence[i];
219     }
220   }
221   return TensorIndex(res);
222 }
223 
// Recursively converts a (possibly nested) list/tuple into nested Python lists.
// CheckRange is applied to `array_like` first. For a non-sequence its result is returned; for a list/tuple
// every element is replaced by the result of DeepList on it and the list is returned.
// NOTE(review): CheckRange presumably validates/normalises values against `dim_size` - confirm at its definition.
py::object TensorIndex::DeepList(const py::object &array_like, int64_t dim_size) {
  py::object checked = CheckRange(array_like, dim_size);
  const bool is_sequence = py::isinstance<py::list>(array_like) || py::isinstance<py::tuple>(array_like);
  if (!is_sequence) {
    return checked;
  }
  auto items = array_like.cast<py::list>();
  for (size_t pos = 0; pos < items.size(); ++pos) {
    items[pos] = DeepList(items[pos], dim_size);
  }
  return items;
}
235 
// Replaces tensors by numpy arrays, descending into lists.
// A Tensor or stub tensor is converted with TensorPy::AsNumpy; a list has each element converted in place;
// every other object is returned unchanged.
py::object TensorIndex::DeepTensorToNdArray(const py::object &array_like) {
  const bool is_stub = IsStubTensor(array_like);
  if (is_stub || py::isinstance<tensor::Tensor>(array_like)) {
    auto tensor_ptr = is_stub ? ConvertStubTensor(array_like) : py::cast<TensorPtr>(array_like);
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    return TensorPy::AsNumpy(*tensor_ptr);
  }
  if (!py::isinstance<py::list>(array_like)) {
    return array_like;
  }
  auto items = array_like.cast<py::list>();
  for (size_t pos = 0; pos < items.size(); ++pos) {
    items[pos] = DeepTensorToNdArray(items[pos]);
  }
  return items;
}
251 
// Builds a numpy array from a Python int, float, bool, list or tuple; any other input raises TypeError.
// The input is passed through CheckRange, and a list/tuple result is additionally normalised with DeepList
// and DeepTensorToNdArray before being converted to py::array on return.
py::array TensorIndex::MakeNdArray(const py::object &a, int64_t dim_size) {
  const bool is_supported = py::isinstance<py::list>(a) || py::isinstance<py::tuple>(a) ||
                            py::isinstance<py::int_>(a) || py::isinstance<py::float_>(a) ||
                            py::isinstance<py::bool_>(a);
  if (!is_supported) {
    MS_EXCEPTION(TypeError) << "Input data must be `int`, `float`, `bool`, `list` or `tuple` but got " << a.get_type();
  }
  py::object new_array = CheckRange(a, dim_size);
  if (py::isinstance<py::list>(new_array) || py::isinstance<py::tuple>(new_array)) {
    new_array = DeepTensorToNdArray(DeepList(new_array, dim_size));
  }
  return new_array;
}
264 
265 namespace Convert {
ConvertTypeToString(const TensorIndex & index)266 string ConvertTypeToString(const TensorIndex &index) {
267   if (index.IsNone())
268     return "None";
269   else if (index.IsEllipsis())
270     return "Ellipsis";
271   else if (index.IsInteger())
272     return "Integer";
273   else if (index.IsBoolean())
274     return "Boolean";
275   else if (index.IsSlice())
276     return "Slice";
277   else if (index.IsTensor())
278     return "Tensor";
279   else if (index.IsList())
280     return "List";
281   else if (index.IsTuple())
282     return "Tuple";
283   else if (index.IsArray())
284     return "Array";
285   else if (index.IsFloat())
286     return "Float";
287   return "Unknown";
288 }
289 }  // namespace Convert
290 
TransformEllipsisToSlice(const ShapeVector & data_shape,const std::vector<TensorIndex> & indices)291 std::vector<TensorIndex> TensorIndex::TransformEllipsisToSlice(const ShapeVector &data_shape,
292                                                                const std::vector<TensorIndex> &indices) {
293   // Check if the tuple index len is longer than the data's dims and transform ellipsis in the indices
294   // to several slice.
295   int64_t ellipsis_occupy_dims = SizeToLong(data_shape.size());
296   int64_t ellipsis_positions = 0;
297   int64_t ellipsis_cnt = 0;
298   for (size_t i = 0; i < indices.size(); i++) {
299     bool valid = (CheckTypeIsInstance<TensorIndexType>(
300       indices[i].type(),
301       {TensorIndexType::List, TensorIndexType::Ellipsis, TensorIndexType::Tuple, TensorIndexType::None,
302        TensorIndexType::Integer, TensorIndexType::Tensor, TensorIndexType::Slice, TensorIndexType::Boolean}));
303     if (!valid) {
304       MS_EXCEPTION(TypeError) << "For tuple index, the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', "
305                                  "'int', 'List', 'Tuple', 'bool', but got type '"
306                               << Convert::ConvertTypeToString(indices[i]) << "', value: " << indices[i];
307     }
308     if (indices[i].IsSlice() || indices[i].IsInteger() || indices[i].IsTensor() || indices[i].IsSequence()) {
309       ellipsis_occupy_dims -= 1;
310     } else if (indices[i].IsEllipsis()) {
311       if (ellipsis_cnt >= 1) {
312         MS_EXCEPTION(IndexError) << "An index can only have a single ellipsis('...')";
313       }
314       ellipsis_cnt += 1;
315       ellipsis_positions = static_cast<int64_t>(i);
316     }
317   }
318   if (ellipsis_occupy_dims < 0) {
319     MS_EXCEPTION(IndexError) << "Tuple index " << indices << " out rang of tensor shape " << data_shape;
320   }
321 
322   if (ellipsis_cnt == 0) {
323     return indices;
324   }
325 
326   std::vector<TensorIndex> empty_slice(ellipsis_occupy_dims, TensorIndex(Slice()));
327   std::vector<TensorIndex> new_indices(indices.begin(), indices.end());
328   MS_EXCEPTION_IF_CHECK_FAIL(ellipsis_positions <= SizeToLong(new_indices.size()), "Index out of vector size.");
329   (void)new_indices.insert(new_indices.erase(new_indices.begin() + ellipsis_positions), empty_slice.begin(),
330                            empty_slice.end());
331   return new_indices;
332 }
333 
GenerateIndexInfoFromTupleOfMixedTensors(const std::vector<int64_t> & tensor_positions,const std::vector<ShapeVector> & tensor_indexes_shapes,const ShapeVector & slice_shapes,const TensorIndex & py_fancy_position)334 std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> TensorIndex::GenerateIndexInfoFromTupleOfMixedTensors(
335   const std::vector<int64_t> &tensor_positions, const std::vector<ShapeVector> &tensor_indexes_shapes,
336   const ShapeVector &slice_shapes, const TensorIndex &py_fancy_position) {
337   bool tensor_index_continue_tag = true;
338   if (tensor_positions.empty()) {
339     tensor_index_continue_tag = false;
340   }
341   for (size_t i = 1; i < tensor_positions.size(); i++) {
342     if (tensor_positions[i] != tensor_positions[i - 1] + 1) {
343       tensor_index_continue_tag = false;
344       break;
345     }
346   }
347   int64_t fancy_position = 0;
348   if (py_fancy_position.IsNone()) {
349     fancy_position = tensor_index_continue_tag ? tensor_positions[0] : 0;
350   } else {
351     fancy_position = py_fancy_position.integer();
352   }
353 
354   ShapeVector broadcast_shape = BroadCastShape(tensor_indexes_shapes);
355 
356   fancy_position = std::min(fancy_position, SizeToLong(slice_shapes.size()));
357   ShapeVector final_shape = slice_shapes;
358   (void)final_shape.insert(final_shape.begin() + fancy_position, broadcast_shape.begin(), broadcast_shape.end());
359 
360   ShapeVector index_tensor_new_shape(slice_shapes.size(), 1);
361   fancy_position = std::min(fancy_position, SizeToLong(index_tensor_new_shape.size()));
362 
363   (void)index_tensor_new_shape.insert(index_tensor_new_shape.begin() + fancy_position, broadcast_shape.begin(),
364                                       broadcast_shape.end());
365 
366   return std::make_tuple(broadcast_shape, index_tensor_new_shape, final_shape, fancy_position);
367 }
368 
// Expands a slice over the first dimension of `shape` into an explicit array of element coordinates.
// Raises when `shape` is empty. Returns the boolean index `false` when (start - stop) * step >= 0 for the
// slice resolved against shape[0]. Otherwise one arange is created per dimension (the slice range for
// dimension 0, the full int32 range for the others), combined with ix_ and broadcast_arrays, and stacked
// along a new last axis.
TensorIndex TensorIndex::SliceToArray(const TensorIndex &tensor_index, const ShapeVector &shape) {
  MS_EXCEPTION_IF_CHECK_FAIL(!shape.empty(), "DataShape of Tensor can not be empty when sed item");
  Slice resolved_slice = Slice(tensor_index.slice(), shape[0]);
  const int64_t start = resolved_slice.start();
  const int64_t stop = resolved_slice.stop();
  const int64_t step = resolved_slice.step();
  if ((start - stop) * step >= 0) {
    return TensorIndex(py::bool_(false));
  }

  const py::module &np = TensorIndex::np_module_;
  int64_t n_dim = SizeToLong(shape.size());
  py::tuple grids(n_dim);
  grids[0] = np.attr("arange")(py::int_(start), py::int_(stop), py::int_(step));
  for (size_t axis = 1; axis < shape.size(); ++axis) {
    grids[axis] = np.attr("arange")(0, py::int_(shape[axis]), 1, np.attr("int32"));
  }

  py::object mesh = np.attr("ix_")(*grids);
  py::tuple broadcast_mesh = np.attr("broadcast_arrays")(*mesh);
  return TensorIndex(np.attr("stack")(broadcast_mesh, -1));
}
389 
SliceToArray(const TensorPtr & index,const ShapeVector & final_shape,size_t slice_cnt,const ShapeVector & broadcast_shape,const ShapeVector & slice_shape,int64_t fancy_position)390 TensorIndex TensorIndex::SliceToArray(const TensorPtr &index, const ShapeVector &final_shape, size_t slice_cnt,
391                                       const ShapeVector &broadcast_shape, const ShapeVector &slice_shape,
392                                       int64_t fancy_position) {
393   ShapeVector shape = ComputeSliceShape(slice_shape, broadcast_shape.size(), slice_cnt, fancy_position);
394 
395   py::object array = TensorPy::SyncAsNumpy(*index);
396   array = TensorIndex::np_module_.attr("ndarray").attr("astype")(array, TensorIndex::np_module_.attr("int32"));
397   array = TensorIndex::np_module_.attr("reshape")(array, py::cast(shape));
398   array = BroadCastTo(final_shape, array);
399   return TensorIndex(array);
400 }
401 
// Calls `broadcast_to` of the numpy module handle on `item` with `broadcast_shape` as the target shape.
py::object TensorIndex::BroadCastTo(const ShapeVector &broadcast_shape, const py::object &item) {
  const py::object target_shape = py::cast(broadcast_shape);
  return TensorIndex::np_module_.attr("broadcast_to")(item, target_shape);
}
405 
BroadCastTensor(const ShapeVector & broadcast_shape,const ShapeVector & final_shape,const ShapeVector & new_shape,const TensorPtr & item)406 TensorIndex TensorIndex::BroadCastTensor(const ShapeVector &broadcast_shape, const ShapeVector &final_shape,
407                                          const ShapeVector &new_shape, const TensorPtr &item) {
408   py::array py_item = TensorPy::SyncAsNumpy(*item);
409   py_item = TensorIndex::np_module_.attr("ndarray").attr("astype")(py_item, TensorIndex::np_module_.attr("int32"));
410   py_item = BroadCastTo(broadcast_shape, py_item);
411   return TensorIndex(BroadCastTo(final_shape, TensorIndex::np_module_.attr("reshape")(py_item, py::cast(new_shape))));
412 }
413 
GetValueTransferType(const TensorIndexType & py_value_type,int64_t op_type,const TypePtr & data_type,bool is_view)414 std::tuple<int64_t, py::object, ShapeVector> TensorIndex::GetValueTransferType(const TensorIndexType &py_value_type,
415                                                                                int64_t op_type,
416                                                                                const TypePtr &data_type, bool is_view) {
417   ValueTransferType value_transfer_type = ValueTransferType::kByPass;
418   py::object value_transfer_arg = py::none();
419   ShapeVector value_shape = {};
420   if (py_value_type == TensorIndexType::Tensor) {
421     if (is_view) {
422       return std::make_tuple(static_cast<int>(value_transfer_type), value_transfer_arg, value_shape);
423     }
424     value_transfer_arg = py::none();
425     if (IsStubTensor(TensorIndex::py_value_handle_)) {
426       value_shape = GetStubTensorInfo(TensorIndex::py_value_handle_).first;
427     } else {
428       auto value_ptr = TensorIndex::py_value_handle_.cast<TensorPtr>();
429       MS_EXCEPTION_IF_NULL(value_ptr);
430       value_shape = value_ptr->shape();
431     }
432   } else if (CheckTypeIsInstance(py_value_type,
433                                  {TensorIndexType::Float, TensorIndexType::Integer, TensorIndexType::Boolean})) {
434     value_transfer_type = ValueTransferType::kNumberToTensor;
435     value_transfer_arg = py::none();
436   } else if (py_value_type == TensorIndexType::List || py_value_type == TensorIndexType::Tuple) {
437     value_transfer_type = ValueTransferType::kHandleSequenceValue;
438     auto py_value_list = TensorIndex::py_value_handle_.cast<py::list>();
439     if (!py_value_list.empty()) {
440       (void)value_shape.emplace_back(SizeToLong(py_value_list.size()));
441       const py::object &first_py_ele = py_value_list[0];
442       TensorPtr ele;
443       if (py::isinstance<Tensor>(first_py_ele) || IsStubTensor(first_py_ele)) {
444         ele = IsStubTensor(first_py_ele) ? ConvertStubTensor(first_py_ele) : py::cast<TensorPtr>(first_py_ele);
445       } else {
446         ele = TensorPy::MakeTensor(py_value_list[0], data_type);
447       }
448       MS_EXCEPTION_IF_NULL(ele);
449       (void)value_shape.insert(value_shape.end(), ele->shape().begin(), ele->shape().end());
450     }
451     value_transfer_arg = py::make_tuple(py::int_(op_type), TensorIndex::py_index_handle_);
452   }
453   return std::make_tuple(static_cast<int>(value_transfer_type), value_transfer_arg, value_shape);
454 }
455 
// Returns `input` converted with ndarray.astype to the int32 type of the numpy module handle.
static py::array CastToInt(const py::array &input) {
  const py::module &np = TensorIndex::np_module_;
  const py::object int32_type = np.attr("int32");
  return np.attr("ndarray").attr("astype")(input, int32_type);
}
459 
// Returns true when the number of elements described by `data_shape` is larger than 1024 * 32.
// An empty shape counts as one element.
static bool CheckLargeTensor(const ShapeVector &data_shape) {
  constexpr int64_t max_dim = 1024 * 32;
  // The seed fixes the accumulator type of std::accumulate. It has to be 64-bit: with an `int` seed the
  // running product is narrowed to int after every step, so large shapes wrap around and look small.
  const int64_t data_shape_dim =
    std::accumulate(data_shape.begin(), data_shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return data_shape_dim > max_dim;
}
465 
466 // ***********************************************for get_item*******************************************
GenerateNonZeroIndex(const ShapeVector & data_shape,const TensorPtr & tensor_index,bool check_align)467 py::tuple TensorIndex::GenerateNonZeroIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index,
468                                             bool check_align) {
469   if (!check_align) {
470     py::array index_array = TensorPy::SyncAsNumpy(*tensor_index);
471     return TensorIndex::np_module_.attr("nonzero")(index_array);
472   }
473   const int64_t data_dim = SizeToLong(data_shape.size());
474   const int64_t index_dims = tensor_index->DataDim();
475   if (data_dim < index_dims) {
476     MS_EXCEPTION(IndexError) << "The dim of index cannot be greater than indexed data, but got dim of index:"
477                              << index_dims << ", dim of data:" << data_dim;
478   }
479   for (size_t i = 0; i < static_cast<size_t>(index_dims); i++) {
480     if (data_shape[i] != tensor_index->shape()[i]) {
481       MS_EXCEPTION(ValueError) << "The shape of index " << tensor_index->shape()
482                                << "does not match the shape of the indexed data " << data_shape << " at dim index" << i;
483     }
484   }
485   py::array index_array = TensorPy::SyncAsNumpy(*tensor_index);
486   return TensorIndex::np_module_.attr("nonzero")(index_array);
487 }
488 
GenerateNonZeroIndexTensorList(const ShapeVector & data_shape,const TensorPtr & tensor_index,bool check_align)489 std::vector<TensorPtr> TensorIndex::GenerateNonZeroIndexTensorList(const ShapeVector &data_shape,
490                                                                    const TensorPtr &tensor_index, bool check_align) {
491   py::tuple nonzero_indices = GenerateNonZeroIndex(data_shape, tensor_index, check_align);
492   MS_EXCEPTION_IF_CHECK_FAIL(!nonzero_indices.empty(), "Output size of nonzero should not be empty");
493   int64_t nonzero_indices_nums = SizeToLong(len(py::array(nonzero_indices[0])));
494   if (nonzero_indices_nums == 0) {
495     return {};
496   }
497   std::vector<TensorPtr> nonzero_indices_tensor_list;
498   (void)std::transform(nonzero_indices.begin(), nonzero_indices.end(), std::back_inserter(nonzero_indices_tensor_list),
499                        [](const py::handle &nonzero_index) {
500                          return TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(nonzero_index));
501                        });
502   return nonzero_indices_tensor_list;
503 }
504 
TensorGetitemByTupleParseTensorIndex(const ShapeVector & data_shape,const TensorPtr & tensor_index,std::vector<TensorPtr> * tuple_index_new,std::vector<TensorPtr> * tensor_indexes,std::vector<int64_t> * tensor_positions,bool check_align)505 bool TensorIndex::TensorGetitemByTupleParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index,
506                                                        std::vector<TensorPtr> *tuple_index_new,
507                                                        std::vector<TensorPtr> *tensor_indexes,
508                                                        std::vector<int64_t> *tensor_positions, bool check_align) {
509   //  parse index of tensor type
510   MS_EXCEPTION_IF_NULL(tensor_index);
511   if (CheckTypeIsInstance<TypeId>(tensor_index->data_type(), kIntTypes)) {
512     tensor_positions->emplace_back(tuple_index_new->size());
513     tuple_index_new->emplace_back(tensor_index);
514     tensor_indexes->emplace_back(tensor_index);
515   } else if (tensor_index->data_type() == kNumberTypeBool) {
516     std::vector<TensorPtr> nonzero_indices_tensors =
517       GenerateNonZeroIndexTensorList(data_shape, tensor_index, check_align);
518     if (nonzero_indices_tensors.empty()) {
519       return false;
520     }
521     int64_t nonzero_indices_position = SizeToLong(tuple_index_new->size());
522     (void)std::transform(nonzero_indices_tensors.begin(), nonzero_indices_tensors.end(),
523                          std::back_inserter(*tensor_positions),
524                          [&nonzero_indices_position](auto &) { return nonzero_indices_position++; });
525     tuple_index_new->insert(tuple_index_new->end(), nonzero_indices_tensors.begin(), nonzero_indices_tensors.end());
526     tensor_indexes->insert(tensor_indexes->end(), nonzero_indices_tensors.begin(), nonzero_indices_tensors.end());
527   } else {
528     MS_EXCEPTION(IndexError) << "The tensor element in tuple index must be int or bool type, but got "
529                              << TypeIdToString(tensor_index->data_type(), false);
530   }
531   return true;
532 }
533 
GetStrideInfoFromTuple(const ShapeVector & data_shape,const std::vector<TensorIndex> & tuple_index)534 std::tuple<std::vector<std::vector<int64_t>>, std::vector<int64_t>> TensorIndex::GetStrideInfoFromTuple(
535   const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index) {
536   const size_t data_dim = data_shape.size();
537   const size_t tuple_index_len = tuple_index.size();
538   const size_t stride_slice_info_size = std::min(tuple_index_len, data_dim);
539   std::vector<int64_t> begin_info(stride_slice_info_size);
540   std::vector<int64_t> end_info(stride_slice_info_size);
541   std::vector<int64_t> step_info(stride_slice_info_size);
542 
543   size_t index_count = 0;
544   int64_t shrink_axis = 0;
545   int64_t ellipsis_count = 0;
546 
547   for (size_t i = 0; i < stride_slice_info_size; i++) {
548     const TensorIndex &index = tuple_index[i];
549 
550     int64_t dim_size = data_shape[i];
551     if (index.IsSlice()) {
552       Slice slice_info = Slice(index.slice(), dim_size);
553       begin_info[i] = slice_info.start();
554       end_info[i] = slice_info.stop();
555       step_info[i] = slice_info.step();
556       index_count += 1;
557     } else if (index.IsInteger()) {
558       const auto mask_bit = 1 << index_count;
559       begin_info[i] = index.integer();
560       end_info[i] = index.integer() + 1;
561       step_info[i] = 1;
562       shrink_axis += mask_bit;
563       index_count += 1;
564     } else if (index.IsEllipsis()) {
565       ellipsis_count = ellipsis_count + 1;
566       if (ellipsis_count > 1) {
567         MS_EXCEPTION(ValueError) << "An Tensor index can have only one ellipsis (...) ";
568       }
569       auto ellipsis_range_size = data_dim - tuple_index_len + 1;
570       for (size_t j = 0; j < ellipsis_range_size; j++) {
571         MS_EXCEPTION_IF_CHECK_FAIL(index_count + j < stride_slice_info_size && index_count + j < data_dim,
572                                    "Index out of data dims");
573         begin_info[index_count + j] = 0;
574         end_info[index_count + j] = data_shape[index_count + j];
575         step_info[index_count + j] = 1;
576       }
577       index_count += ellipsis_range_size;
578     }
579   }
580 
581   int64_t begin_mask = 0;
582   int64_t end_mask = 0;
583 
584   for (size_t i = 0; i < tuple_index_len; i++) {
585     if (tuple_index[i].IsSlice()) {
586       Slice slice_info = tuple_index[i].slice();
587       const auto mask_bit = 1 << i;
588       if (slice_info.start_init_by_none()) {
589         begin_mask += mask_bit;
590       }
591       if (slice_info.stop_init_by_none()) {
592         end_mask += mask_bit;
593       }
594     }
595   }
596   for (size_t i = tuple_index_len; i < data_dim; i++) {
597     const auto mask_bit = 1 << i;
598     begin_mask += mask_bit;
599     end_mask += mask_bit;
600   }
601 
602   return std::make_tuple(std::vector<std::vector<int64_t>>({begin_info, end_info, step_info}),
603                          std::vector<int64_t>({begin_mask, end_mask, shrink_axis}));
604 }
605 
GetExpandDimsInfo(const ShapeVector & data_shape,const std::vector<TensorIndex> & index)606 std::tuple<bool, ShapeVector, std::vector<TensorIndex>> TensorIndex::GetExpandDimsInfo(
607   const ShapeVector &data_shape, const std::vector<TensorIndex> &index) {
608   bool need_expand_dims = std::any_of(index.begin(), index.end(), [](auto &x) { return x.IsNone() || x.IsBoolean(); });
609   if (!need_expand_dims) {
610     return std::make_tuple(false, ShapeVector(), std::vector<TensorIndex>());
611   }
612   std::vector<TensorIndex> new_tuple_index;
613   std::vector<int64_t> expand_dims_info;
614   for (size_t i = 0; i < index.size(); i++) {
615     if (index[i].IsNone()) {
616       (void)new_tuple_index.emplace_back(tensor::Slice());
617       (void)expand_dims_info.emplace_back(i);
618     } else if (index[i].IsBoolean()) {
619       if (!index[i].boolean()) {
620         MS_EXCEPTION(IndexError) << "Bool element of tuple index must be 'True', but got 'False'.";
621       }
622       (void)new_tuple_index.emplace_back(std::make_shared<Tensor>(std::vector<int64_t>({0})));
623       (void)expand_dims_info.emplace_back(i);
624     } else {
625       (void)new_tuple_index.emplace_back(index[i]);
626     }
627   }
628   auto reshape_info = data_shape;
629   for (auto dim : expand_dims_info) {
630     dim = std::min(dim, SizeToLong(reshape_info.size()));
631     (void)reshape_info.insert(reshape_info.begin() + dim, 1);
632   }
633 
634   return std::make_tuple(need_expand_dims, reshape_info, new_tuple_index);
635 }
636 
GenerateIndices(const std::vector<TensorPtr> & tuple_index_new,const std::vector<int64_t> & broadcast_shape,const std::vector<int64_t> & index_tensor_new_shape,const std::vector<int64_t> & final_shape,const std::vector<int64_t> & tensor_positions,const std::vector<int64_t> & slice_shapes,int64_t fancy_position)637 py::object TensorIndex::GenerateIndices(const std::vector<TensorPtr> &tuple_index_new,
638                                         const std::vector<int64_t> &broadcast_shape,
639                                         const std::vector<int64_t> &index_tensor_new_shape,
640                                         const std::vector<int64_t> &final_shape,
641                                         const std::vector<int64_t> &tensor_positions,
642                                         const std::vector<int64_t> &slice_shapes, int64_t fancy_position) {
643   py::tuple final_index_tensors(tuple_index_new.size());
644   size_t slice_cnt = 0;
645   for (size_t i = 0; i < tuple_index_new.size(); i++) {
646     if (std::find(tensor_positions.begin(), tensor_positions.end(), i) != tensor_positions.end()) {
647       TensorIndex transform_tensor =
648         BroadCastTensor(broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]);
649       final_index_tensors[i] = transform_tensor.array();
650     } else {
651       TensorIndex slice_index_tensor =
652         SliceToArray(tuple_index_new[i], final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position);
653 
654       final_index_tensors[i] = slice_index_tensor.array();
655       slice_cnt += 1;
656     }
657   }
658   return TensorIndex::np_module_.attr("array")(TensorIndex::np_module_.attr("stack")(final_index_tensors, -1));
659 }
660 
TensorGetitemByTuple(const ShapeVector & data_shape,const std::vector<TensorIndex> & tuple_index,std::vector<int64_t> * data_transfer_types,std::vector<py::object> * data_transfer_args)661 py::object TensorIndex::TensorGetitemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index,
662                                              std::vector<int64_t> *data_transfer_types,
663                                              std::vector<py::object> *data_transfer_args) {
664   size_t data_dims = data_shape.size();
665   std::vector<TensorPtr> tensor_indexes;
666   std::vector<TensorPtr> tuple_index_new;
667   std::vector<int64_t> slice_shapes;
668   std::vector<int64_t> tensor_positions;
669   size_t tuple_index_len = tuple_index.size();
670   bool empty_mask_tensor = false;
671   const size_t min_length = std::min(data_dims, tuple_index_len);
672   for (size_t i = 0; i < min_length; i++) {
673     int64_t dim_size = data_shape[i];
674     const TensorIndex &index = tuple_index[i];
675 
676     if (index.IsInteger()) {
677       int64_t int_index = index.integer();
678       if (int_index >= dim_size || int_index < -dim_size) {
679         MS_EXCEPTION(IndexError) << "Index " << int_index << " is out of bounds for dimension with size " << dim_size;
680       }
681       int_index = CheckRange(int_index, dim_size);
682       TensorPtr tensor_index = std::make_shared<Tensor>(int_index);
683       (void)tensor_positions.emplace_back(tuple_index_new.size());
684       (void)tuple_index_new.emplace_back(tensor_index);
685       (void)tensor_indexes.emplace_back(tensor_index);
686     } else if (index.IsSequence()) {
687       TensorIndex sequence_list = SequenceToTensor(index, data_shape[i]);
688       TensorPtr tensor_index = sequence_list.tensor();
689       (void)tensor_positions.emplace_back(tuple_index_new.size());
690       (void)tuple_index_new.emplace_back(tensor_index);
691       (void)tensor_indexes.emplace_back(tensor_index);
692     } else if (index.IsTensor()) {
693       const TensorPtr &tensor_index = index.tensor();
694       if (!TensorGetitemByTupleParseTensorIndex(data_shape, tensor_index, &tuple_index_new, &tensor_indexes,
695                                                 &tensor_positions, false)) {
696         TensorPtr new_tensor_index = std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({0}));
697         for (int j = 0; j < tensor_index->DataDim(); j++) {
698           (void)tensor_positions.emplace_back(tuple_index_new.size());
699           (void)tuple_index_new.emplace_back(new_tensor_index);
700           (void)tensor_indexes.emplace_back(new_tensor_index);
701         }
702         empty_mask_tensor = true;
703       }
704     } else if (index.IsSlice()) {
705       Slice slice_info = Slice(index.slice(), dim_size);
706       int64_t start = slice_info.start();
707       int64_t stop = slice_info.stop();
708       int64_t step = slice_info.step();
709 
710       std::vector<int64_t> slice_ele_list_index;
711       for (int64_t j = start; j < stop; j += step) {
712         (void)slice_ele_list_index.emplace_back(j);
713       }
714       (void)slice_shapes.emplace_back(SizeToLong(slice_ele_list_index.size()));
715       (void)tuple_index_new.emplace_back(std::make_shared<Tensor>(slice_ele_list_index));
716     }
717   }
718   tuple_index_len = tuple_index.size();
719   std::vector<ShapeVector> tensor_indexes_shapes;
720   (void)std::transform(
721     tensor_indexes.begin(), tensor_indexes.end(), std::back_inserter(tensor_indexes_shapes), [](auto &tensor_index) {
722       if (tensor_index == nullptr) {
723         MS_EXCEPTION(IndexError) << "IndexError: The sequence element(tuple/list) in tuple index can't be empty.";
724       }
725       return tensor_index->shape();
726     });
727   std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> index_info = GenerateIndexInfoFromTupleOfMixedTensors(
728     tensor_positions, tensor_indexes_shapes, slice_shapes, TensorIndex(py::none()));
729   constexpr size_t broadcast_shape_index = 0;
730   constexpr size_t index_tensor_new_shape_index = 1;
731   constexpr size_t final_shape_index = 2;
732   constexpr size_t fancy_position_index = 3;
733   ShapeVector broadcast_shape = std::get<broadcast_shape_index>(index_info);
734   ShapeVector index_tensor_new_shape = std::get<index_tensor_new_shape_index>(index_info);
735   ShapeVector final_shape = std::get<final_shape_index>(index_info);
736   int64_t fancy_position = std::get<fancy_position_index>(index_info);
737   if (empty_mask_tensor) {
738     (void)data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kEmptyTensor));
739     (void)data_transfer_args->emplace_back(VectorToPyTuple(final_shape));
740     return py::make_tuple(py::none(), VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
741   }
742   if (std::find(final_shape.begin(), final_shape.end(), 0) != final_shape.end() ||
743       std::find(data_shape.begin(), data_shape.end(), 0) != data_shape.end()) {
744     if (tuple_index_len < data_dims) {
745       (void)final_shape.insert(final_shape.end(), data_shape.begin() + SizeToLong(tuple_index_len), data_shape.end());
746     }
747     data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kEmptyTensor));
748     data_transfer_args->emplace_back(VectorToPyTuple(final_shape));
749     return py::make_tuple(py::none(), VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
750   }
751 
752   data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kGatherND));
753   data_transfer_args->emplace_back(py::make_tuple(
754     VectorToPyTuple(broadcast_shape), VectorToPyTuple(final_shape), VectorToPyTuple(index_tensor_new_shape),
755     VectorToPyTuple(slice_shapes), VectorToPyTuple(tensor_positions), fancy_position));
756   if (CheckLargeTensor(data_shape)) {
757     return py::make_tuple(tuple_index_new, VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
758   }
759   py::array new_index = GenerateIndices(tuple_index_new, broadcast_shape, index_tensor_new_shape, final_shape,
760                                         tensor_positions, slice_shapes, fancy_position);
761   return py::make_tuple(TensorPy::MakeTensor(CastToInt(new_index)), VectorToPyTuple(*data_transfer_types),
762                         VectorToPyTuple(*data_transfer_args));
763 }
764 
765 // ***********************************************for set_item*******************************************
// Converts a list index into either a tensor index or a tuple index.
//
// When every element of the list is a python int or bool the list is handed to SequenceToTensor<py::list>;
// otherwise it goes through DeepList and the result is wrapped as a tuple index. `length` is forwarded unchanged
// to whichever helper is used.
TensorIndex TensorIndex::FormatList(const TensorIndex &tensor_index, int64_t length) {
  const auto &items = tensor_index.list_;
  const auto is_other_type = [](const auto &item) {
    return !(py::isinstance<py::int_>(item) || py::isinstance<py::bool_>(item));
  };
  if (std::any_of(items.begin(), items.end(), is_other_type)) {
    return TensorIndex(DeepList(items, length).cast<py::tuple>());
  }
  return SequenceToTensor<py::list>(items, length);
}
775 
IntToTensor(int64_t int_index,const ShapeVector & shape)776 TensorPtr TensorIndex::IntToTensor(int64_t int_index, const ShapeVector &shape) {
777   int64_t dim_size = shape[0];
778   auto out_i = static_cast<int32_t>(CheckRange(int_index, dim_size));
779   if (shape.size() == 1) {
780     return std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({1, 1}), &out_i, int32_bytes_number);
781   }
782 
783   ShapeVector index_shape(shape.begin() + 1, shape.end());
784   int64_t grids_size = SizeToLong(shape.size()) - 1;
785   py::tuple grids(grids_size);
786   for (size_t i = 1; i < shape.size(); i++) {
787     grids[i - 1] =
788       TensorIndex::np_module_.attr("arange")(0, py::int_(shape[i]), 1, TensorIndex::np_module_.attr("int32"));
789   }
790   py::object mesh = TensorIndex::np_module_.attr("ix_")(*grids);
791   py::tuple index(SizeToLong(shape.size()));
792   index[0] =
793     TensorIndex::np_module_.attr("full")(py::cast(index_shape), py::int_(out_i), TensorIndex::np_module_.attr("int32"));
794   py::tuple broadcast_mesh = TensorIndex::np_module_.attr("broadcast_arrays")(*mesh);
795   for (size_t i = 1; i < shape.size(); i++) {
796     index[i] = broadcast_mesh[i - 1];
797   }
798   py::object output_index = TensorIndex::np_module_.attr("stack")(index, -1);
799   return TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index));
800 }
801 
GenerateIndicesFromTupleOfTensor(const ShapeVector & data_shape,const std::vector<TensorIndex> & tuple_index,ShapeVector * output_index_shape,py::object * data_transfer_arg)802 py::object TensorIndex::GenerateIndicesFromTupleOfTensor(const ShapeVector &data_shape,
803                                                          const std::vector<TensorIndex> &tuple_index,
804                                                          ShapeVector *output_index_shape,
805                                                          py::object *data_transfer_arg) {
806   std::vector<ShapeVector> tensor_index_shape;
807   std::vector<TensorPtr> tuple_index_vector;
808   for (const auto &index : tuple_index) {
809     TensorPtr index_tensor = index.tensor();
810     MS_EXCEPTION_IF_NULL(index_tensor);
811     (void)tuple_index_vector.emplace_back(index_tensor);
812     if (!CheckTypeIsInstance<TypeId>(index_tensor->data_type(), kIntTypes)) {
813       string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
814       MS_EXCEPTION(IndexError) << "For '" << index_op_type << "', the index tensor data type '"
815                                << index_tensor->data_type() << "' is not supported.";
816     }
817   }
818   (void)std::transform(tuple_index_vector.begin(), tuple_index_vector.end(), std::back_inserter(tensor_index_shape),
819                        [](const TensorPtr &x) { return x->shape(); });
820   ShapeVector broadcast_shape = BroadCastShape(tensor_index_shape);
821 
822   constexpr int64_t min_broadcast_shape_size = 2;
823   if (SizeToLong(broadcast_shape.size()) < min_broadcast_shape_size) {
824     (void)broadcast_shape.insert(broadcast_shape.begin(), 1);
825   }
826 
827   *output_index_shape = broadcast_shape;
828   output_index_shape->emplace_back(tuple_index.size());
829   if (CheckLargeTensor(data_shape)) {
830     *data_transfer_arg = py::make_tuple(VectorToPyTuple(broadcast_shape));
831     return VectorToPyTuple(tuple_index_vector);
832   }
833 
834   std::vector<py::array> broadcast_tensors;
835   (void)std::transform(tuple_index.begin(), tuple_index.end(), std::back_inserter(broadcast_tensors),
836                        [&broadcast_shape](auto &index) {
837                          return TensorIndex::np_module_.attr("broadcast_to")(
838                            CastToInt(TensorPy::SyncAsNumpy(*index.tensor())), broadcast_shape);
839                        });
840   py::array output_index = TensorIndex::np_module_.attr("stack")(py::cast(broadcast_tensors), -1);
841   return py::cast(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index)));
842 }
843 
RemNotExpandedDims(int64_t * idx_advanced,bool expand_true,int64_t tensor_index_ndim,int64_t rem_ndim,std::vector<bool> * not_expanded_dim)844 void TensorIndex::RemNotExpandedDims(int64_t *idx_advanced, bool expand_true, int64_t tensor_index_ndim,
845                                      int64_t rem_ndim, std::vector<bool> *not_expanded_dim) {
846   if (*idx_advanced != -1) {
847     std::vector<bool> tensor_dims(tensor_index_ndim, true);
848     if (expand_true) {
849       tensor_dims = {false};
850     }
851     *idx_advanced = std::min(*idx_advanced, SizeToLong(not_expanded_dim->size()));
852     not_expanded_dim->insert(not_expanded_dim->begin() + *idx_advanced, tensor_dims.begin(), tensor_dims.end());
853   }
854   std::vector<bool> rem_ndim_vector(rem_ndim, true);
855   not_expanded_dim->insert(not_expanded_dim->end(), rem_ndim_vector.begin(), rem_ndim_vector.end());
856   size_t count_leading_false = 0;
857   while (count_leading_false < not_expanded_dim->size() && !((*not_expanded_dim)[count_leading_false])) {
858     count_leading_false += 1;
859   }
860   *idx_advanced = std::max(static_cast<int64_t>(0), *idx_advanced - SizeToLong(count_leading_false));
861 }
862 
// Normalises one element of a tuple index against dimension `cur_dim` of `data_shape`.
//
//   * Elements that are not a list, tuple, integer or tensor are returned unchanged.
//   * Sequences are converted with SequenceToTensor; integers go through CheckRange and are wrapped in a tensor.
//   * Integer tensors: when CheckLargeTensor(data_shape) holds, the index is returned untouched and *need_format
//     is set to true; otherwise negative entries are shifted by the dimension size and a new tensor is returned.
//   * Bool tensors are returned unchanged; any other tensor dtype raises IndexError.
//
// *need_format is only ever written with `true`; the caller is expected to initialise it.
// The check on `cur_dim` fails (exception) when it is not a valid dimension of `data_shape`.
TensorIndex TensorIndex::FormatIndex(const TensorIndex &idx, const ShapeVector &data_shape, size_t cur_dim,
                                     bool *need_format) {
  if (!CheckTypeIsInstance<TensorIndexType>(idx.type(), {TensorIndexType::List, TensorIndexType::Tuple,
                                                         TensorIndexType::Integer, TensorIndexType::Tensor})) {
    return idx;
  }
  // Separate the literal fragments from the numbers so the message reads "Index 3 out of data dims 2".
  MS_EXCEPTION_IF_CHECK_FAIL(cur_dim < data_shape.size(), "Index " + std::to_string(cur_dim) +
                                                            " out of data dims " +
                                                            std::to_string(data_shape.size()));
  int64_t dims_size = data_shape[cur_dim];
  if (idx.IsSequence()) {
    return SequenceToTensor(idx, dims_size);
  } else if (idx.IsInteger()) {
    return TensorIndex(std::make_shared<Tensor>(CheckRange(idx.integer(), dims_size)));
  }
  const TensorPtr &tensor_idx = idx.tensor();
  MS_EXCEPTION_IF_NULL(tensor_idx);
  if (CheckTypeIsInstance<TypeId>(tensor_idx->data_type(), kIntTypes)) {
    if (CheckLargeTensor(data_shape)) {
      // Leave the tensor as it is and only flag it.
      *need_format = true;
      return idx;
    }
    py::array new_idx = TensorPy::SyncAsNumpy(*tensor_idx);
    if (tensor_idx->DataDim() == 0) {
      // Scalar tensor: wrap a negative value around directly, without going through numpy.
      auto new_int_idx = new_idx.cast<int64_t>();
      new_int_idx = new_int_idx < 0 ? new_int_idx + dims_size : new_int_idx;
      return TensorIndex(std::make_shared<Tensor>(new_int_idx));
    }
    // numpy op select is very slow for one dim array
    new_idx = TensorIndex::np_module_.attr("expand_dims")(new_idx, 0);
    new_idx = TensorIndex::np_module_.attr("select")(TensorIndex::np_module_.attr("less")(new_idx, 0),
                                                     TensorIndex::np_module_.attr("add")(new_idx, py::int_(dims_size)),
                                                     new_idx);
    new_idx = TensorIndex::np_module_.attr("squeeze")(new_idx, 0);
    return TensorIndex(TensorPy::MakeTensor(CastToInt(new_idx)));
  } else if (tensor_idx->data_type() != kNumberTypeBool) {
    string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
    MS_EXCEPTION(IndexError) << "For '" << index_op_type << "', the index tensor data type '"
                             << TypeIdToString(tensor_idx->data_type(), false) << "' is not supported.";
  }
  return idx;
}
904 
RemoveExpandedDimsParseTensorIndex(const ShapeVector & data_shape,const TensorPtr & index_out,std::vector<TensorIndex> * indices_out,std::vector<ShapeVector> * shapes,bool * has_sequence,size_t * cur_dim,bool check_align)905 bool TensorIndex::RemoveExpandedDimsParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &index_out,
906                                                      std::vector<TensorIndex> *indices_out,
907                                                      std::vector<ShapeVector> *shapes, bool *has_sequence,
908                                                      size_t *cur_dim, bool check_align) {
909   // Parse tensor_index
910   MS_EXCEPTION_IF_NULL(index_out);
911   if (index_out->data_type() == kNumberTypeBool) {
912     std::vector<TensorPtr> nonzero_indices_tensors = GenerateNonZeroIndexTensorList(data_shape, index_out, check_align);
913     if (nonzero_indices_tensors.empty()) {
914       return false;
915     }
916     std::vector<TensorIndex> true_index_tensors;
917     (void)std::transform(nonzero_indices_tensors.begin(), nonzero_indices_tensors.end(),
918                          std::back_inserter(true_index_tensors),
919                          [](const TensorPtr &true_index) { return TensorIndex(true_index); });
920     size_t true_index_nums = nonzero_indices_tensors.size();
921     indices_out->insert(indices_out->end(), true_index_tensors.begin(), true_index_tensors.end());
922     MS_EXCEPTION_IF_NULL(nonzero_indices_tensors[0]);
923     std::vector<ShapeVector> true_index_shapes(true_index_nums, {nonzero_indices_tensors[0]->shape()});
924     shapes->insert(shapes->end(), true_index_shapes.begin(), true_index_shapes.end());
925     *cur_dim += true_index_nums;
926   } else {
927     if (index_out->DataDim() > 0) {
928       *has_sequence = true;
929     }
930     indices_out->emplace_back(index_out);
931     shapes->emplace_back(index_out->shape());
932     *cur_dim += 1;
933   }
934   return true;
935 }
936 
RemoveExpandedDims(const std::vector<TensorIndex> & indices,const ShapeVector & data_shape,const ShapeVector & value_shape,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args,int64_t * idx_advanced,bool * by_pass,std::vector<size_t> * format_index,std::vector<int64_t> * format_dim)937 std::pair<std::vector<TensorIndex>, ShapeVector> TensorIndex::RemoveExpandedDims(
938   const std::vector<TensorIndex> &indices, const ShapeVector &data_shape, const ShapeVector &value_shape,
939   std::vector<int64_t> *value_transfer_types, std::vector<py::object> *value_transfer_args, int64_t *idx_advanced,
940   bool *by_pass, std::vector<size_t> *format_index, std::vector<int64_t> *format_dim) {
941   // Removes expanded dimensions in tuple_index and value.
942   size_t cur_dim = 0;
943   bool has_true = false;
944   bool has_false = false;
945   bool has_sequence = false;
946   int64_t idx_tensor = -1;
947   std::vector<bool> not_expanded_dim;
948   std::vector<TensorIndex> indices_out;
949   std::vector<ShapeVector> shapes;
950 
951   for (size_t i = 0; i < indices.size(); i++) {
952     const TensorIndex &v = indices[i];
953     bool need_format = false;
954     TensorIndex index_out = TensorIndex::FormatIndex(v, data_shape, cur_dim, &need_format);
955     if (need_format) {
956       (void)format_index->emplace_back(cur_dim);
957       (void)format_dim->emplace_back(data_shape[cur_dim]);
958     }
959     if (index_out.IsNone()) {
960       (void)not_expanded_dim.emplace_back(false);
961     } else if (index_out.IsSlice()) {
962       (void)indices_out.emplace_back(index_out);
963       (void)not_expanded_dim.emplace_back(true);
964       Slice slice_info = Slice(v.slice(), data_shape[cur_dim]);
965 
966       int64_t start = slice_info.start();
967       int64_t stop = slice_info.stop();
968       int64_t step = slice_info.step();
969       has_false = ((start - stop) * step > 0) || has_false;
970       cur_dim += 1;
971     } else if (index_out.IsBoolean() || index_out.IsTensor()) {
972       if (*idx_advanced == -1) {
973         *idx_advanced = SizeToLong(not_expanded_dim.size());
974       } else if (static_cast<int64_t>(i) - idx_tensor > 1) {
975         *idx_advanced = 0;
976       }
977       idx_tensor = static_cast<int64_t>(i);
978       if (index_out.IsTensor()) {
979         const TensorPtr &index_out_tensor = index_out.tensor();
980         if (!RemoveExpandedDimsParseTensorIndex(data_shape, index_out_tensor, &indices_out, &shapes, &has_sequence,
981                                                 &cur_dim, false)) {
982           *by_pass = true;
983           *idx_advanced = 0;
984           return {std::vector<TensorIndex>(), ShapeVector()};
985         }
986       } else {
987         bool bool_index_out = index_out.boolean();
988         has_true = bool_index_out || has_true;
989         has_false = !bool_index_out || has_false;
990       }
991     } else {
992       MS_EXCEPTION(IndexError) << "Invalid index type, index: " << TensorIndex::py_index_handle_;
993     }
994   }
995 
996   ShapeVector broadcast_shape = BroadCastShape(shapes);
997   if (has_false) {
998     if (std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<>()) != 1) {
999       MS_EXCEPTION(IndexError) << "Unable to broadcast indices " << broadcast_shape;
1000     }
1001     *by_pass = true;
1002     return std::make_pair(std::vector<TensorIndex>(), ShapeVector());
1003   }
1004 
1005   bool expand_true = has_true && !(has_false || has_sequence);
1006   int64_t tensor_index_ndim = SizeToLong(broadcast_shape.size());
1007   int64_t rem_ndim = SizeToLong(data_shape.size()) - SizeToLong(cur_dim);
1008   RemNotExpandedDims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, &not_expanded_dim);
1009   if (indices_out.empty()) {
1010     indices_out = {TensorIndex(py::bool_(true))};
1011   }
1012   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kReshape));
1013   ShapeVector reshape_info = FilterExpandedDims(value_shape, not_expanded_dim);
1014   value_transfer_args->emplace_back(py::cast(reshape_info));
1015   *by_pass = false;
1016   return std::make_pair(indices_out, reshape_info);
1017 }
1018 
GenerateIndicesFromTuple(const ShapeVector & data_shape,const std::vector<TensorIndex> & tuple_index,int64_t py_fancy_position,bool * by_pass,ShapeVector * output_index_shape,py::object * data_transfer_arg)1019 py::object TensorIndex::GenerateIndicesFromTuple(const ShapeVector &data_shape,
1020                                                  const std::vector<TensorIndex> &tuple_index, int64_t py_fancy_position,
1021                                                  bool *by_pass, ShapeVector *output_index_shape,
1022                                                  py::object *data_transfer_arg) {
1023   std::vector<TensorPtr> tensor_indexes;
1024   std::vector<TensorPtr> tuple_index_new;
1025   std::vector<int64_t> slice_shapes;
1026   std::vector<int64_t> tensor_positions;
1027   std::vector<ShapeVector> tensor_indexes_shapes;
1028   const size_t min_length = std::min(data_shape.size(), tuple_index.size());
1029   for (size_t i = 0; i < min_length; i++) {
1030     const TensorIndex &index = tuple_index[i];
1031     int64_t dim_size = data_shape[i];
1032 
1033     if (index.IsInteger()) {
1034       int64_t int_index = index.integer();
1035       if (int_index >= dim_size || int_index < -dim_size) {
1036         MS_EXCEPTION(IndexError) << "Index " << int_index << " is out of bounds for dimension with size " << dim_size;
1037       }
1038       int_index = CheckRange(int_index, dim_size);
1039       TensorPtr tensor_index = std::make_shared<Tensor>(int_index);
1040       MS_EXCEPTION_IF_NULL(tensor_index);
1041       (void)tuple_index_new.emplace_back(tensor_index);
1042       (void)tensor_indexes.emplace_back(tensor_index);
1043       (void)tensor_positions.emplace_back(i);
1044       (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
1045     } else if (index.IsSequence()) {
1046       TensorIndex sequence_list = SequenceToTensor(index, data_shape[i]);
1047       TensorPtr tensor_index = sequence_list.tensor();
1048       (void)tuple_index_new.emplace_back(tensor_index);
1049       (void)tensor_indexes.emplace_back(tensor_index);
1050       (void)tensor_positions.emplace_back(i);
1051       MS_EXCEPTION_IF_NULL(tensor_index);
1052       (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
1053     } else if (index.IsTensor()) {
1054       TensorPtr tensor_index = index.tensor();
1055       if (!CheckTypeIsInstance<TypeId>(tensor_index->data_type(), kIntTypes)) {
1056         MS_EXCEPTION(TypeError) << "The tensor element in tuple index must be int type, but got "
1057                                 << tensor_index->data_type();
1058       }
1059       (void)tuple_index_new.emplace_back(tensor_index);
1060       (void)tensor_indexes.emplace_back(tensor_index);
1061       (void)tensor_positions.emplace_back(i);
1062       (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
1063     } else if (index.IsSlice()) {
1064       Slice slice_info = Slice(index.slice(), dim_size);
1065       int64_t start = slice_info.start();
1066       int64_t stop = slice_info.stop();
1067       int64_t step = slice_info.step();
1068       if ((start - stop) * step >= 0) {
1069         *by_pass = true;
1070         return py::none();
1071       }
1072       std::vector<int64_t> slice_ele_list_index = SliceToVector(start, stop, step);
1073       (void)slice_shapes.emplace_back(SizeToLong(slice_ele_list_index.size()));
1074       (void)tuple_index_new.emplace_back(std::make_shared<Tensor>(slice_ele_list_index));
1075     }
1076   }
1077 
1078   std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> index_info = GenerateIndexInfoFromTupleOfMixedTensors(
1079     tensor_positions, tensor_indexes_shapes, slice_shapes, TensorIndex(py_fancy_position));
1080   constexpr size_t k_broadcast_shape_index = 0;
1081   constexpr size_t index_tensor_new_shape_index = 1;
1082   constexpr size_t final_shape_index = 2;
1083   constexpr size_t fancy_position_index = 3;
1084   ShapeVector broadcast_shape = std::get<k_broadcast_shape_index>(index_info);
1085   ShapeVector index_tensor_new_shape = std::get<index_tensor_new_shape_index>(index_info);
1086   ShapeVector final_shape = std::get<final_shape_index>(index_info);
1087   *output_index_shape = final_shape;
1088   output_index_shape->emplace_back(tuple_index_new.size());
1089   int64_t fancy_position = std::get<fancy_position_index>(index_info);
1090   if (CheckLargeTensor(data_shape)) {
1091     *data_transfer_arg = py::make_tuple(VectorToPyTuple(broadcast_shape), VectorToPyTuple(final_shape),
1092                                         VectorToPyTuple(index_tensor_new_shape), VectorToPyTuple(slice_shapes),
1093                                         VectorToPyTuple(tensor_positions), fancy_position);
1094     return VectorToPyTuple(tuple_index_new);
1095   }
1096   py::array output_index = GenerateIndices(tuple_index_new, broadcast_shape, index_tensor_new_shape, final_shape,
1097                                            tensor_positions, slice_shapes, fancy_position);
1098   return py::cast(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index)));
1099 }
1100 
ReSetitemByTensor(const std::vector<TensorIndex> & new_tuple_index,const std::vector<int64_t> & value_transfer_types,const std::vector<py::object> & value_transfer_args)1101 py::object TensorIndex::ReSetitemByTensor(const std::vector<TensorIndex> &new_tuple_index,
1102                                           const std::vector<int64_t> &value_transfer_types,
1103                                           const std::vector<py::object> &value_transfer_args) {
1104   py::object output_py_index;
1105   if (new_tuple_index[0].IsSlice()) {
1106     Slice slice_info = new_tuple_index[0].slice();
1107     output_py_index = py::slice(slice_info.start(), slice_info.stop(), slice_info.step());
1108   } else if (new_tuple_index[0].IsTensor()) {
1109     output_py_index = py::cast(new_tuple_index[0].tensor());
1110   } else {
1111     output_py_index = py::cast(new_tuple_index[0].boolean());
1112   }
1113   return py::make_tuple(
1114     output_py_index, VectorToPyTuple<int64_t>(value_transfer_types), VectorToPyTuple<py::object>(value_transfer_args),
1115     py::make_tuple(static_cast<int>(ValueTransferType::kReSetItemByIndex)), py::make_tuple(py::none()));
1116 }
1117 
SetitemByTupleWithTensor(const ShapeVector & data_shape,const std::vector<TensorIndex> & indices,const ShapeVector & value_shape,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args)1118 py::object TensorIndex::SetitemByTupleWithTensor(const ShapeVector &data_shape, const std::vector<TensorIndex> &indices,
1119                                                  const ShapeVector &value_shape,
1120                                                  std::vector<int64_t> *value_transfer_types,
1121                                                  std::vector<py::object> *value_transfer_args) {
1122   std::vector<TensorIndex> new_indices = TransformEllipsisToSlice(data_shape, indices);
1123   ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
1124   if (UseCopySlice(new_indices, SizeToLong(data_shape.size())) && !TensorIndex::is_ascend_) {
1125     Slice slice_info = Slice(new_indices[1].slice(), data_shape[1]);
1126     int64_t dim1_start = slice_info.start();
1127     int64_t dim1_stop = slice_info.stop();
1128     if (dim1_stop - dim1_start <= 0) {
1129       tensor_update_type = ValueTransferType::kByPass;
1130       return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
1131                             VectorToPyTuple<py::object>(*value_transfer_args),
1132                             py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1133     }
1134     if (data_shape.empty()) {
1135       MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1136     }
1137     int64_t dim0_start =
1138       new_indices[0].integer() >= 0 ? new_indices[0].integer() : new_indices[0].integer() + data_shape[0];
1139     py::tuple start = py::make_tuple(dim0_start, dim1_start);
1140     py::tuple stop = py::make_tuple(dim0_start + 1, dim1_stop);
1141     py::tuple step = py::make_tuple(1, 1);
1142 
1143     ShapeVector new_value_shape = {dim1_stop - dim1_start};
1144     constexpr int64_t start_position_of_data_shape = 2;
1145     (void)new_value_shape.insert(new_value_shape.end(), data_shape.begin() + start_position_of_data_shape,
1146                                  data_shape.end());
1147     value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1148     value_transfer_args->emplace_back(VectorToPyTuple(new_value_shape));
1149     value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1150     value_transfer_args->emplace_back(py::none());
1151     tensor_update_type = ValueTransferType::kCopySlice;
1152     return py::make_tuple(
1153       py::none(), VectorToPyTuple<int64_t>(*value_transfer_types), VectorToPyTuple<py::object>(*value_transfer_args),
1154       py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::make_tuple(start, stop, step)));
1155   }
1156   int64_t idx_advanced = -1;
1157   bool by_pass = false;
1158   std::vector<size_t> format_index;
1159   std::vector<int64_t> format_dim;
1160   std::pair<std::vector<TensorIndex>, ShapeVector> tuple_index_info =
1161     RemoveExpandedDims(new_indices, data_shape, value_shape, value_transfer_types, value_transfer_args, &idx_advanced,
1162                        &by_pass, &format_index, &format_dim);
1163   if (by_pass) {
1164     tensor_update_type = ValueTransferType::kByPass;
1165     return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
1166                           VectorToPyTuple<py::object>(*value_transfer_args),
1167                           py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1168   }
1169 
1170   MS_LOG(DEBUG) << "After remove expand dims: " << tuple_index_info.first;
1171 
1172   std::vector<TensorIndex> new_tuple_index = tuple_index_info.first;
1173   ShapeVector new_value_shape = tuple_index_info.second;
1174 
1175   if (new_tuple_index.size() == 1) {
1176     return ReSetitemByTensor(new_tuple_index, *value_transfer_types, *value_transfer_args);
1177   }
1178   py::object output_index;
1179   ShapeVector output_index_shape;
1180   py::object data_transfer_args = py::none();
1181   if (std::all_of(new_tuple_index.begin(), new_tuple_index.end(), [](const TensorIndex &x) { return x.IsTensor(); })) {
1182     output_index =
1183       GenerateIndicesFromTupleOfTensor(data_shape, new_tuple_index, &output_index_shape, &data_transfer_args);
1184   } else {
1185     by_pass = false;
1186     output_index = GenerateIndicesFromTuple(data_shape, new_tuple_index, idx_advanced, &by_pass, &output_index_shape,
1187                                             &data_transfer_args);
1188     if (by_pass) {
1189       tensor_update_type = ValueTransferType::kByPass;
1190       return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
1191                             VectorToPyTuple<py::object>(*value_transfer_args),
1192                             py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1193     }
1194   }
1195 
1196   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1197   value_transfer_args->emplace_back(py::make_tuple());
1198   ShapeVector updates_shape(output_index_shape.begin(), output_index_shape.end() - 1);
1199 
1200   if (output_index_shape.back() < SizeToLong(data_shape.size())) {
1201     (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + output_index_shape.back(), data_shape.end());
1202   }
1203 
1204   if (updates_shape != new_value_shape) {
1205     value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1206     value_transfer_args->emplace_back(VectorToPyTuple(updates_shape));
1207   }
1208   std::vector<int> tensor_update_types{static_cast<int>(tensor_update_type)};
1209   std::vector<py::object> tensor_update_args{data_transfer_args};
1210   if (!format_index.empty()) {
1211     (void)tensor_update_types.insert(tensor_update_types.begin(),
1212                                      static_cast<int>(ValueTransferType::kFormatIndexTensor));
1213     (void)tensor_update_args.insert(tensor_update_args.begin(), py::make_tuple(VectorToPyTuple<size_t>(format_index),
1214                                                                                VectorToPyTuple<int64_t>(format_dim)));
1215   }
1216   if (py::isinstance<py::tuple>(output_index)) {
1217     return py::make_tuple(py::cast<py::list>(output_index), VectorToPyTuple<int64_t>(*value_transfer_types),
1218                           VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
1219                           VectorToPyTuple<py::object>(tensor_update_args));
1220   }
1221   return py::make_tuple(py::cast<TensorPtr>(output_index), VectorToPyTuple<int64_t>(*value_transfer_types),
1222                         VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
1223                         VectorToPyTuple<py::object>(tensor_update_args));
1224 }
1225 
GetStubTensorValue(const py::handle & obj)1226 ValuePtr GetStubTensorValue(const py::handle &obj) {
1227   auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
1228   ValuePtr stub = py_stub.cast<stub::StubNodePtr>();
1229   if (stub == nullptr) {
1230     auto tensor_ptr = py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
1231     MS_EXCEPTION_IF_NULL(tensor_ptr);
1232     stub = tensor_ptr;
1233   }
1234   return stub;
1235 }
1236 
SqueezeRDataValue(const TensorPtr & tensor,const py::handle & py_value,const ValuePtr & rdata_value)1237 ValuePtr SqueezeRDataValue(const TensorPtr &tensor, const py::handle &py_value, const ValuePtr &rdata_value) {
1238   auto rdata_shape = tensor->shape();
1239   if (rdata_shape.size() >= 1 && (rdata_shape.at(0) > 1 || rdata_shape.size() > 1)) {
1240     MS_EXCEPTION(ValueError)
1241       << "For SetItem, the shape of right value must be () or (1, ) when shape of left value is 0, but got"
1242       << rdata_shape;
1243   } else if (rdata_shape.size() == 1 && rdata_shape.at(0) == 1) {
1244     auto new_value = py::cast<py::list>(py_value);
1245     auto first_value = new_value[0];
1246     ValuePtr result =
1247       IsStubTensor(first_value) ? GetStubTensorValue(first_value) : first_value.cast<tensor::TensorPtr>();
1248     return result;
1249   }
1250   return rdata_value;
1251 }
1252 
SetitemCopyView(std::vector<pynative::SliceOpInfoPtr> * slice_op_infos,const ValuePtr data_value,const std::vector<int64_t> & new_data_shape,const TypePtr & data_type,const py::handle & py_value)1253 static inline py::object SetitemCopyView(std::vector<pynative::SliceOpInfoPtr> *slice_op_infos,
1254                                          const ValuePtr data_value, const std::vector<int64_t> &new_data_shape,
1255                                          const TypePtr &data_type, const py::handle &py_value) {
1256   auto cast_op_info = std::make_shared<pynative::SliceOpInfo>();
1257   cast_op_info->slice_op_name = prim::kPrimCast->name();
1258   (void)cast_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(data_type->type_id()));
1259   cast_op_info->data_indexs = {1};
1260   (void)slice_op_infos->emplace_back(cast_op_info);
1261 
1262   auto broadcastto_op_info = std::make_shared<pynative::SliceOpInfo>();
1263   broadcastto_op_info->slice_op_name = prim::kPrimBroadcastTo->name();
1264   (void)broadcastto_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(new_data_shape));
1265   broadcastto_op_info->data_indexs = {1};
1266   (void)slice_op_infos->emplace_back(broadcastto_op_info);
1267 
1268   auto copy_op_info = std::make_shared<pynative::SliceOpInfo>();
1269   copy_op_info->slice_op_name = kCopyWithSliceOpName;
1270   copy_op_info->data_indexs = {0, 1};
1271   (void)slice_op_infos->emplace_back(copy_op_info);
1272   ValuePtr rdata_value;
1273   if (IsStubTensor(py_value)) {
1274     rdata_value = GetStubTensorValue(py_value);
1275     if (new_data_shape.size() == 0) {
1276       auto tensor = ConvertStubTensor(py_value);
1277       rdata_value = SqueezeRDataValue(tensor, py_value, rdata_value);
1278     }
1279   } else if (py::isinstance<Tensor>(py_value)) {
1280     auto tensor = py_value.cast<TensorPtr>();
1281     MS_EXCEPTION_IF_NULL(tensor);
1282     rdata_value = tensor;
1283     if (new_data_shape.size() == 0) {
1284       rdata_value = SqueezeRDataValue(tensor, py_value, rdata_value);
1285     }
1286   } else if (py::isinstance<py::int_>(py_value)) {
1287     rdata_value = MakeValue(py::cast<int64_t>(py_value));
1288   } else if (py::isinstance<py::float_>(py_value)) {
1289     rdata_value = MakeValue(py::cast<float>(py_value));
1290   } else if (py::isinstance<py::bool_>(py_value)) {
1291     rdata_value = MakeValue(py::cast<bool>(py_value));
1292   } else {
1293     return py::none();
1294   }
1295   return pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value, rdata_value}, *slice_op_infos);
1296 }
1297 
SetitemBySliceWithTensor(const ShapeVector & data_shape,const TensorIndex & slice_index,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args,const ValuePtr & data_value,const TypePtr & data_type)1298 py::object TensorIndex::SetitemBySliceWithTensor(const ShapeVector &data_shape, const TensorIndex &slice_index,
1299                                                  std::vector<int64_t> *value_transfer_types,
1300                                                  std::vector<py::object> *value_transfer_args,
1301                                                  const ValuePtr &data_value, const TypePtr &data_type) {
1302   ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
1303   Slice slice_info = Slice(slice_index.slice(), data_shape[0]);
1304   int64_t start = slice_info.start();
1305   int64_t stop = slice_info.stop();
1306   int64_t step = slice_info.step();
1307   if (step >= 0 && data_value != nullptr) {
1308     std::vector<int64_t> data_transfer_types;
1309     std::vector<py::object> data_transfer_args;
1310     std::vector<int64_t> begin_info(data_shape.size(), 0);
1311     std::vector<int64_t> end_info(data_shape);
1312     std::vector<int64_t> step_info(data_shape.size(), 1);
1313     std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
1314     if (start >= stop) {
1315       (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
1316       return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
1317                             py::tuple());
1318     }
1319     if (slice_info.start() != 0 || slice_info.step() != 1 || slice_info.stop() != end_info[0]) {
1320       begin_info[0] = slice_info.start();
1321       end_info[0] = slice_info.stop();
1322       step_info[0] = slice_info.step();
1323       auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1324       slice_op_info->slice_op_name = prim::kPrimStridedSlice->name();
1325       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(begin_info));
1326       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(end_info));
1327       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(step_info));
1328       (void)slice_op_info->data_indexs.emplace_back(0);
1329       (void)slice_op_infos.emplace_back(slice_op_info);
1330     }
1331     auto new_data_shape = data_shape;
1332     if (step != 0) {
1333       auto new_shape_zero = (stop - start) / step;
1334       new_data_shape[0] = (new_shape_zero < 0 ? 0 : (stop + step - 1 - start) / step);
1335     }
1336     auto slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
1337     if (slice_output != py::none()) {
1338       data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
1339       data_transfer_args.emplace_back(slice_output);
1340       return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
1341                             VectorToPyTuple(data_transfer_args));
1342     }
1343     (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kStrideSlice));
1344     (void)data_transfer_args.emplace_back(py::make_tuple(
1345       py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()), py::make_tuple(slice_info.step())));
1346     (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCopyView));
1347     (void)data_transfer_args.emplace_back(py::none());
1348     return py::make_tuple(py::str("view"), VectorToPyTuple<int64_t>(*value_transfer_types),
1349                           VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple(data_transfer_types),
1350                           VectorToPyTuple(data_transfer_args));
1351   }
1352   if (slice_index.slice().step() == 1 && !TensorIndex::is_ascend_) {
1353     if (data_shape.empty()) {
1354       MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1355     }
1356     int64_t dim0_size = stop - start;
1357     if (dim0_size <= 0) {
1358       tensor_update_type = ValueTransferType::kByPass;
1359       return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
1360                             VectorToPyTuple<py::object>(*value_transfer_args),
1361                             py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1362     }
1363     ShapeVector value_shape = {dim0_size};
1364     (void)value_shape.insert(value_shape.end(), data_shape.begin() + 1, data_shape.end());
1365     value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1366     value_transfer_args->emplace_back(VectorToPyTuple(value_shape));
1367     value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1368     value_transfer_args->emplace_back(py::none());
1369     tensor_update_type = ValueTransferType::kCopySlice;
1370     return py::make_tuple(
1371       py::none(), VectorToPyTuple<int64_t>(*value_transfer_types), VectorToPyTuple<py::object>(*value_transfer_args),
1372       py::make_tuple(static_cast<int>(tensor_update_type)),
1373       py::make_tuple(py::make_tuple(py::make_tuple(start), py::make_tuple(stop), py::make_tuple(step))));
1374   }
1375   TensorIndex indices = SliceToArray(slice_index, data_shape);
1376   if (indices.IsBoolean()) {
1377     tensor_update_type = ValueTransferType::kByPass;
1378     return py::make_tuple(indices.boolean(), VectorToPyTuple<int64_t>(*value_transfer_types),
1379                           VectorToPyTuple<py::object>(*value_transfer_args),
1380                           py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1381   }
1382   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1383   TensorPtr indices_tensor = TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(indices.array()));
1384   MS_EXCEPTION_IF_NULL(indices_tensor);
1385   ShapeVector broad_cast_shape(indices_tensor->shape().begin(), indices_tensor->shape().end() - 1);
1386   value_transfer_args->emplace_back(VectorToPyTuple(broad_cast_shape));
1387   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1388   value_transfer_args->emplace_back(py::none());
1389   return py::make_tuple(indices_tensor, VectorToPyTuple<int64_t>(*value_transfer_types),
1390                         VectorToPyTuple<py::object>(*value_transfer_args),
1391                         py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
1392 }
1393 
SetItemByTensorByBool(const ShapeVector & data_shape,const TensorPtr & index,int64_t data_dims,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args,ValueTransferType * tensor_update_type)1394 py::array TensorIndex::SetItemByTensorByBool(const ShapeVector &data_shape, const TensorPtr &index, int64_t data_dims,
1395                                              std::vector<int64_t> *value_transfer_types,
1396                                              std::vector<py::object> *value_transfer_args,
1397                                              ValueTransferType *tensor_update_type) {
1398   ShapeVector index_shape = GeneratePaddingShape(index->shape(), data_dims);
1399   py::array np_index = TensorPy::SyncAsNumpy(*index);
1400   py::array output_np_index = TensorIndex::np_module_.attr("broadcast_to")(
1401     TensorIndex::np_module_.attr("reshape")(np_index, VectorToPyTuple(index_shape)), VectorToPyTuple(data_shape));
1402   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1403   value_transfer_args->emplace_back(py::none());
1404   value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1405   value_transfer_args->emplace_back(VectorToPyTuple(data_shape));
1406   *tensor_update_type = ValueTransferType::kSelect;
1407   return output_np_index;
1408 }
1409 
1410 // ***********************************************get get_item info*******************************************
GetItemByTensor(const ShapeVector & data_shape,const TensorPtr & index)1411 py::object TensorIndex::GetItemByTensor(const ShapeVector &data_shape, const TensorPtr &index) {
1412   MS_EXCEPTION_IF_NULL(index);
1413   MS_LOG(DEBUG) << "In branch get item by tensor, data_shape: " << data_shape
1414                 << " tensor_indexes: " << index->ToString();
1415   constexpr int min_data_dim = 1;
1416   constexpr int max_data_dim = 7;
1417   const int64_t data_dim = SizeToLong(data_shape.size());
1418   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1419   py::object output = py::none();
1420   if (CheckTypeIsInstance<TypeId>(index->data_type(), kIntTypes)) {
1421     output =
1422       py::make_tuple(index, py::make_tuple(static_cast<int>(ValueTransferType::kGather)), py::make_tuple(py::none()));
1423   } else if (index->data_type() == kNumberTypeBool) {
1424     py::tuple nonzero_indices = GenerateNonZeroIndex(data_shape, index, true);
1425     MS_EXCEPTION_IF_CHECK_FAIL(!nonzero_indices.empty(), "Output size of nonzero should not be empty");
1426     int64_t nonzero_indices_nums = SizeToLong(len(py::array(nonzero_indices[0])));
1427     if (nonzero_indices_nums == 0) {
1428       ShapeVector empty_tensor_shape(data_shape.begin() + index->DataDim(), data_shape.end());
1429       (void)empty_tensor_shape.insert(empty_tensor_shape.begin(), 0);
1430 
1431       return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kEmptyTensor)),
1432                             py::make_tuple(VectorToPyTuple(empty_tensor_shape)));
1433     }
1434     output = py::make_tuple(index, py::make_tuple(static_cast<int>(ValueTransferType::kGetitemByBoolTensor)),
1435                             py::make_tuple(py::none()));
1436   } else {
1437     MS_EXCEPTION(IndexError) << "The tensor index must be int or bool type, but got " << TensorIndex::py_index_handle_;
1438   }
1439   return output;
1440 }
1441 
GetItemByList(const ShapeVector & data_shape,const TensorIndex & tensor_index)1442 py::object TensorIndex::GetItemByList(const ShapeVector &data_shape, const TensorIndex &tensor_index) {
1443   MS_LOG(DEBUG) << "In branch get item by List, data_shape: " << data_shape << " tensor_index: " << tensor_index;
1444   constexpr int min_data_dim = 1;
1445   constexpr int max_data_dim = 8;
1446   int64_t data_dim = SizeToLong(data_shape.size());
1447   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1448   bool use_gather = std::all_of(tensor_index.list().begin(), tensor_index.list().end(),
1449                                 [](auto &x) { return py::isinstance<py::int_>(x) || py::isinstance<py::bool_>(x); });
1450   if (use_gather) {
1451     if (data_shape.empty()) {
1452       MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1453     }
1454     TensorIndex tuple_index = SequenceToTensor(tensor_index, data_shape[0]);
1455     if (tuple_index.IsBoolean() && !tuple_index.boolean()) {
1456       MS_EXCEPTION(IndexError) << "When tensor is indexed by list, the list can't be empty.";
1457     }
1458     return py::make_tuple(tuple_index.tensor(), py::make_tuple(static_cast<int>(ValueTransferType::kGather)),
1459                           py::make_tuple(py::none()));
1460   }
1461   return GetItemByTuple(data_shape, tensor_index.ExpandToVector());
1462 }
1463 
JudgeTupleIndexDim(int64_t data_dim,const std::vector<TensorIndex> & new_tuple_indexes)1464 static void JudgeTupleIndexDim(int64_t data_dim, const std::vector<TensorIndex> &new_tuple_indexes) {
1465   int64_t index_dims = 0;
1466   for (const TensorIndex &index : new_tuple_indexes) {
1467     if (index.IsTensor() && index.tensor() != nullptr && index.tensor()->data_type() == kNumberTypeBool) {
1468       index_dims += index.tensor()->DataDim();
1469     } else {
1470       index_dims += 1;
1471     }
1472   }
1473   if (index_dims > data_dim) {
1474     MS_EXCEPTION(IndexError) << "The dim of index cannot be greater than indexed data, but got dim of index:"
1475                              << index_dims << ", dim of data:" << data_dim;
1476   }
1477 }
1478 
GetSpecifiedDimensions(const py::tuple & new_tuple_index,size_t data_dims)1479 size_t GetSpecifiedDimensions(const py::tuple &new_tuple_index, size_t data_dims) {
1480   size_t specified_dimensions = std::count_if(new_tuple_index.begin(), new_tuple_index.end(), [](auto const &obj) {
1481     return (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False);
1482   });
1483   constexpr size_t max_data_dim = 8;
1484   if (data_dims > max_data_dim) {
1485     MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [0, " << max_data_dim << "], but got '"
1486                              << data_dims << "'.";
1487   }
1488   if (specified_dimensions > data_dims) {
1489     MS_EXCEPTION(IndexError) << "too many indices for tensor of dimension" << data_dims;
1490   }
1491   return specified_dimensions;
1492 }
1493 
1494 namespace {
CheckDataDim(const ShapeVector & data_shape)1495 void CheckDataDim(const ShapeVector &data_shape) {
1496   constexpr size_t max_data_dim = 8;
1497   if (data_shape.size() > max_data_dim) {
1498     MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [1, " << max_data_dim << "], but got '"
1499                              << data_shape.size() << "'.";
1500   }
1501 }
1502 
CheckNumberOfEllipsis(const size_t counter)1503 void CheckNumberOfEllipsis(const size_t counter) {
1504   if (counter > 0) {
1505     MS_EXCEPTION(IndexError) << "An index can only have a single ellipsis('...')";
1506   }
1507 }
1508 }  // namespace
1509 
GetItemByTupleWithView(const ValuePtr & data_value,const ShapeVector & data_shape,const py::object & py_index,std::vector<int64_t> * data_transfer_types,std::vector<py::object> * data_transfer_args,const TypePtr & data_type)1510 bool TensorIndex::GetItemByTupleWithView(const ValuePtr &data_value, const ShapeVector &data_shape,
1511                                          const py::object &py_index, std::vector<int64_t> *data_transfer_types,
1512                                          std::vector<py::object> *data_transfer_args, const TypePtr &data_type) {
1513   if (data_value == nullptr) {
1514     return false;
1515   }
1516   MS_LOG(DEBUG) << "In branch get item by tuple with view, data_shape: " << data_shape
1517                 << " tensor_indexes: " << py_index;
1518   size_t data_dims = data_shape.size();
1519   auto new_tuple_index = py_index.cast<py::tuple>();
1520   size_t specified_dimensions = GetSpecifiedDimensions(new_tuple_index, data_dims);
1521   bool empty_strided_slice_result = false;
1522   auto new_data_shape = data_shape;
1523   size_t dim = 0;
1524   std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
1525   size_t ellipsis_count = 0;
1526   for (auto const &obj : new_tuple_index) {
1527     if (py::isinstance<py::int_>(obj) && !py::isinstance<py::bool_>(obj)) {
1528       auto index = py::cast<int64_t>(obj);
1529       if (index >= new_data_shape[dim] || index < -new_data_shape[dim]) {
1530         // Raise exception in python, because python iterator need raise IndexError to stop for loop.
1531         data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kRaiseIndexError));
1532         data_transfer_args->emplace_back(py::make_tuple(index, new_data_shape[dim]));
1533         return true;
1534       }
1535       int64_t transformed_number = CheckRange(index, new_data_shape[dim]);
1536       auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1537       slice_op_info->slice_op_name = prim::kPrimSelectView->name();
1538       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
1539       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
1540       (void)slice_op_info->data_indexs.emplace_back(0);
1541       (void)slice_op_infos.emplace_back(slice_op_info);
1542       (void)new_data_shape.erase(new_data_shape.begin() + dim);
1543     } else if (py::isinstance<py::slice>(obj)) {
1544       auto slice_info = Slice(TensorIndex(obj).slice(), new_data_shape[dim]);
1545       std::vector<int64_t> begin_info(new_data_shape.size(), 0);
1546       std::vector<int64_t> end_info(new_data_shape);
1547       std::vector<int64_t> step_info(new_data_shape.size(), 1);
1548       if (slice_info.step() < 0) {
1549         data_transfer_types->clear();
1550         data_transfer_args->clear();
1551         return false;
1552       }
1553       if (slice_info.start() == 0 && slice_info.step() == 1 && slice_info.stop() == end_info[dim]) {
1554         dim++;
1555         continue;
1556       }
1557       empty_strided_slice_result = (slice_info.start() >= slice_info.stop());
1558       begin_info[dim] = slice_info.start();
1559       end_info[dim] = slice_info.stop();
1560       step_info[dim] = slice_info.step();
1561       auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1562       slice_op_info->slice_op_name = prim::kPrimStridedSlice->name();
1563       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(begin_info));
1564       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(end_info));
1565       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(step_info));
1566       (void)slice_op_info->data_indexs.emplace_back(0);
1567       (void)slice_op_infos.emplace_back(slice_op_info);
1568       new_data_shape[dim] = (slice_info.stop() + slice_info.step() - 1 - slice_info.start()) / slice_info.step();
1569       dim++;
1570     } else if (py::isinstance<py::ellipsis>(obj)) {
1571       CheckNumberOfEllipsis(ellipsis_count);
1572       dim += data_shape.size() - specified_dimensions;
1573       ellipsis_count += 1;
1574     } else if (py::isinstance<py::none>(obj)) {
1575       auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1576       slice_op_info->slice_op_name = prim::kPrimExpandDims->name();
1577       (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
1578       (void)slice_op_info->data_indexs.emplace_back(0);
1579       (void)slice_op_infos.emplace_back(slice_op_info);
1580       new_data_shape.insert(new_data_shape.begin() + dim, 1);
1581       dim++;
1582     } else {
1583       data_transfer_types->clear();
1584       data_transfer_args->clear();
1585       return false;
1586     }
1587   }
1588   CheckDataDim(new_data_shape);
1589   py::object slice_output;
1590   if (data_type != nullptr) {
1591     if (empty_strided_slice_result) {
1592       data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kByPass));
1593       data_transfer_args->emplace_back(py::none());
1594       return true;
1595     }
1596     slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
1597     if (slice_output == py::none()) {
1598       return false;
1599     }
1600   } else {
1601     if (slice_op_infos.empty()) {
1602       data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kByPass));
1603       data_transfer_args->emplace_back(py::none());
1604       return true;
1605     }
1606     slice_output = pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value}, slice_op_infos);
1607   }
1608   data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
1609   data_transfer_args->emplace_back(slice_output);
1610   return true;
1611 }
1612 
GetItemByTuple(const ShapeVector & data_shape,const std::vector<TensorIndex> & tensor_indexes)1613 py::object TensorIndex::GetItemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tensor_indexes) {
1614   MS_LOG(DEBUG) << "In branch get item by tuple, data_shape: " << data_shape << " tensor_indexes: " << tensor_indexes;
1615   std::vector<int64_t> data_transfer_types;
1616   std::vector<py::object> data_transfer_args;
1617   ShapeVector new_data_shape = data_shape;
1618   if (tensor_indexes.empty()) {
1619     return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)),
1620                           py::make_tuple(py::none()));
1621   }
1622   std::vector<TensorIndex> new_tuple_indexes = TransformEllipsisToSlice(new_data_shape, tensor_indexes);
1623   std::tuple expand_dim_info = GetExpandDimsInfo(new_data_shape, new_tuple_indexes);
1624   constexpr size_t expand_dim_info_index = 0;
1625   constexpr size_t new_data_shape_index = 1;
1626   constexpr size_t new_tuple_indexes_index = 2;
1627   bool need_expand_dim = std::get<expand_dim_info_index>(expand_dim_info);
1628   if (need_expand_dim) {
1629     (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kReshape));
1630     new_data_shape = std::get<new_data_shape_index>(expand_dim_info);
1631     (void)data_transfer_args.emplace_back(VectorToPyTuple(new_data_shape));
1632     new_tuple_indexes = std::get<new_tuple_indexes_index>(expand_dim_info);  // NOLINT
1633   }
1634   constexpr int min_data_dim = 1;
1635   constexpr int max_data_dim = 8;
1636   int64_t data_dim = SizeToLong(new_data_shape.size());
1637   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1638   JudgeTupleIndexDim(data_dim, new_tuple_indexes);
1639   bool normal_tuple = std::all_of(new_tuple_indexes.begin(), new_tuple_indexes.end(), [](auto &index_e) {
1640     return index_e.IsEllipsis() || index_e.IsInteger() || index_e.IsSlice();
1641   });
1642   if (normal_tuple) {
1643     std::tuple stride_slice_info = GetStrideInfoFromTuple(new_data_shape, new_tuple_indexes);
1644     (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kStrideSliceWithMask));
1645     std::vector<std::vector<int64_t>> stride_info = std::get<0>(stride_slice_info);
1646     std::vector<py::tuple> py_stride_info;
1647     (void)std::transform(stride_info.begin(), stride_info.end(), std::back_inserter(py_stride_info),
1648                          [](auto &stride_info_i) { return VectorToPyTuple(stride_info_i); });
1649     std::vector<int64_t> mask_info = std::get<1>(stride_slice_info);
1650     (void)data_transfer_args.emplace_back(py::make_tuple(VectorToPyTuple(py_stride_info), VectorToPyTuple(mask_info)));
1651     return py::make_tuple(py::none(), VectorToPyTuple(data_transfer_types), VectorToPyTuple(data_transfer_args));
1652   }
1653   return TensorGetitemByTuple(new_data_shape, new_tuple_indexes, &data_transfer_types, &data_transfer_args);
1654 }
1655 
GetItemByBool(const ValuePtr & data_value,const ShapeVector & data_shape,bool index)1656 py::object TensorIndex::GetItemByBool(const ValuePtr &data_value, const ShapeVector &data_shape, bool index) {
1657   MS_LOG(INFO) << "(View) In branch get item by bool, data_shape: " << data_shape << " tensor_indexes: " << index;
1658   constexpr int min_data_dim = 0;
1659   constexpr int max_data_dim = 7;
1660   int64_t data_dim = SizeToLong(data_shape.size());
1661   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1662   if (!index) {
1663     MS_EXCEPTION(IndexError) << "When tensor is indexed by a bool object, the value only support 'True'.";
1664   }
1665   auto transfer_type = (data_value == nullptr ? ValueTransferType::kExpandDims : ValueTransferType::kUnsqueeze);
1666   return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(transfer_type)), py::make_tuple(py::int_(0)));
1667 }
1668 
GetItemByNumber(const ShapeVector & data_shape,int64_t index)1669 py::object TensorIndex::GetItemByNumber(const ShapeVector &data_shape, int64_t index) {
1670   MS_LOG(DEBUG) << "In branch get item by number, data_shape: " << data_shape << " tensor_indexes: " << index;
1671   if (data_shape.empty()) {
1672     MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1673   }
1674   constexpr int min_data_dim = 1;
1675   constexpr int max_data_dim = 8;
1676   int64_t data_dim = SizeToLong(data_shape.size());
1677   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1678   if (index >= data_shape[0] || index < -data_shape[0]) {
1679     // Raise exception in python, because python iterator need raise IndexError to stop for loop.
1680     return py::make_tuple(py::make_tuple(py::none()),
1681                           py::make_tuple(static_cast<int>(ValueTransferType::kRaiseIndexError)),
1682                           py::make_tuple(py::make_tuple(index, data_shape[0])));
1683   }
1684   int64_t transformed_number = CheckRange(index, data_shape[0]);
1685   if (!TensorIndex::is_ascend_) {
1686     return py::make_tuple(std::make_shared<Tensor>(transformed_number),
1687                           py::make_tuple(static_cast<int>(ValueTransferType::kGather)), py::make_tuple(py::none()));
1688   }
1689   std::vector<int64_t> begin_strides = {transformed_number};
1690   std::vector<int64_t> end_strides = {transformed_number + 1};
1691   std::vector<int64_t> step_strides = {1};
1692   for (size_t i = 1; i < data_shape.size(); i++) {
1693     (void)begin_strides.emplace_back(0);
1694     (void)end_strides.emplace_back(data_shape[i]);
1695     (void)step_strides.emplace_back(1);
1696   }
1697   int64_t shrink_axis_mask = 1;
1698   int64_t begin_mask = 0;
1699   int64_t end_mask = 0;
1700   constexpr size_t begin_mask_begin_bit = 2;
1701   constexpr size_t begin_mask_end_bit = 8;
1702   for (size_t i = begin_mask_begin_bit; i < begin_mask_end_bit; i++) {
1703     const auto mask_bit = 1 << i;
1704     begin_mask += mask_bit;
1705     end_mask += mask_bit;
1706   }
1707 
1708   py::tuple stride_info =
1709     py::make_tuple(VectorToPyTuple(begin_strides), VectorToPyTuple(end_strides), VectorToPyTuple(step_strides));
1710   py::tuple mask_info = py::make_tuple(begin_mask, end_mask, shrink_axis_mask);
1711   return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSliceWithMask)),
1712                         py::make_tuple(py::make_tuple(stride_info, mask_info)));
1713 }
1714 
GetItemByNumberWithView(const ValuePtr & data_value,const ShapeVector & data_shape,int64_t index)1715 py::object TensorIndex::GetItemByNumberWithView(const ValuePtr &data_value, const ShapeVector &data_shape,
1716                                                 int64_t index) {
1717   MS_LOG(INFO) << "(View) In branch get item by number, data_shape: " << data_shape << " tensor_indexes: " << index;
1718   if (data_shape.empty()) {
1719     MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1720   }
1721   constexpr int min_data_dim = 1;
1722   constexpr int max_data_dim = 8;
1723   int64_t data_dim = SizeToLong(data_shape.size());
1724   JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1725   if (index >= data_shape[0] || index < -data_shape[0]) {
1726     // Raise exception in python, because python iterator need raise IndexError to stop for loop.
1727     return py::make_tuple(py::make_tuple(py::none()),
1728                           py::make_tuple(static_cast<int>(ValueTransferType::kRaiseIndexError)),
1729                           py::make_tuple(py::make_tuple(index, data_shape[0])));
1730   }
1731   int64_t transformed_number = CheckRange(index, data_shape[0]);
1732   // return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kSelectView)),
1733   //                       py::make_tuple(py::make_tuple(py::int_(transformed_number), py::int_(0))));
1734   int64_t dim = 0;
1735   auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1736 
1737   slice_op_info->slice_op_name = prim::kPrimSelectView->name();
1738   (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
1739   (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
1740   (void)slice_op_info->data_indexs.emplace_back(0);
1741 
1742   auto slice_output = pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value}, {slice_op_info});
1743   return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kJustReturn)),
1744                         py::make_tuple(slice_output));
1745 }
1746 
GetItemBySlice(const ValuePtr & data_value,const ShapeVector & data_shape,const TensorIndex & py_index)1747 py::object TensorIndex::GetItemBySlice(const ValuePtr &data_value, const ShapeVector &data_shape,
1748                                        const TensorIndex &py_index) {
1749   MS_LOG(INFO) << "(View) In branch get item by slice, data_shape: " << data_shape << " tensor_indexes: " << py_index;
1750   constexpr int min_data_dim = 1;
1751   constexpr int max_data_dim = 8;
1752   size_t data_dim = data_shape.size();
1753   JudgeDataDim(SizeToLong(data_dim), min_data_dim, max_data_dim);
1754   if (data_shape.empty()) {
1755     MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1756   }
1757   Slice slice_info = Slice(py_index.slice(), data_shape[0]);
1758   if (slice_info.step() >= 0 && data_value != nullptr) {
1759     std::vector<int64_t> begin_info(data_dim, 0);
1760     std::vector<int64_t> end_info(data_shape);
1761     std::vector<int64_t> step_info(data_dim, 1);
1762     begin_info[0] = slice_info.start();
1763     end_info[0] = slice_info.stop();
1764     step_info[0] = slice_info.step();
1765     return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSlice)),
1766                           py::make_tuple(py::make_tuple(VectorToPyTuple(begin_info), VectorToPyTuple(end_info),
1767                                                         VectorToPyTuple(step_info))));
1768   }
1769   int64_t begin_mask = slice_info.start_init_by_none() ? 1 : 0;
1770   int64_t end_mask = slice_info.stop_init_by_none() ? 1 : 0;
1771   for (size_t i = 1; i < data_dim; i++) {
1772     const auto mask_bit = 1 << i;
1773     begin_mask += mask_bit;
1774     end_mask += mask_bit;
1775   }
1776   if (begin_mask != 0 || end_mask != 0) {
1777     py::tuple stride_info = py::make_tuple(py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()),
1778                                            py::make_tuple(slice_info.step()));
1779     py::tuple mask_info = py::make_tuple(begin_mask, end_mask, 0);
1780     return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSliceWithMask)),
1781                           py::make_tuple(py::make_tuple(stride_info, mask_info)));
1782   }
1783   return py::make_tuple(
1784     py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSlice)),
1785     py::make_tuple(py::make_tuple(py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()),
1786                                   py::make_tuple(slice_info.step()))));
1787 }
1788 
// Fast-path dispatcher for simple Python indices (bool, int, slice, None).
// Returns the transfer info produced by the matching handler, or py::none() when the index
// is not handled here so the caller can fall through to the general path.
// The bool test must come before the int test, and the int view path is only taken when a
// data value is available.
py::object TensorIndex::GetItemIndexSimpleIndex(const py::object &py_index, const ValuePtr &data_value,
                                                const ShapeVector &data_shape) {
  if (py::isinstance<py::bool_>(py_index)) {
    return GetItemByBool(data_value, data_shape, TensorIndex(py_index).boolean());
  }

  const bool has_data_value = data_value != nullptr;
  if (has_data_value && py::isinstance<py::int_>(py_index)) {
    return GetItemByNumberWithView(data_value, data_shape, TensorIndex(py_index).integer());
  }

  // NOTE(review): for a non-slice index this probes slice().step() of a freshly built
  // TensorIndex; confirm what step a non-slice TensorIndex is expected to report.
  bool take_slice_branch = py::isinstance<py::slice>(py_index);
  if (!take_slice_branch) {
    take_slice_branch = (TensorIndex(py_index).slice().step() == -1);
  }
  if (take_slice_branch) {
    return GetItemBySlice(data_value, data_shape, TensorIndex(py_index));
  }

  // A None index is routed through the bool handler with a true index.
  if (py::isinstance<py::none>(py_index)) {
    return GetItemByBool(data_value, data_shape, true);
  }
  return py::none();
}
1805 
GetStubAbsTypeId(const AbstractBasePtr & abs)1806 TypeId GetStubAbsTypeId(const AbstractBasePtr &abs) {
1807   MS_EXCEPTION_IF_NULL(abs);
1808 
1809   if (abs->isa<abstract::AbstractTensor>()) {
1810     auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
1811     MS_EXCEPTION_IF_NULL(tensor_abs);
1812     MS_EXCEPTION_IF_NULL(tensor_abs->element());
1813     MS_EXCEPTION_IF_NULL(tensor_abs->element()->BuildType());
1814     return tensor_abs->element()->BuildType()->type_id();
1815   } else {
1816     MS_EXCEPTION_IF_NULL(abs->BuildType());
1817     return abs->BuildType()->type_id();
1818   }
1819 }
1820 
EnableView(bool is_setitem=false)1821 bool EnableView(bool is_setitem = false) {
1822   if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->is_high_order_top_cell()) {
1823     // 1. pack node will slice failed with view.
1824     // 2. SelectView and CopyWithSlice has no kernel, can not enable view in high order cell.
1825     return false;
1826   }
1827 
1828   // For setitem, the grad of CopyWithSlice is erroneous. If we are in setitem and requires grad, disable view.
1829   if (is_setitem && pynative::PyNativeExecutor::GetInstance()->grad_executor()->RequiresGrad()) return false;
1830 
1831   return true;
1832 }
1833 
GetItemIndexInfo(const py::object & py_data,const py::object & py_index,const py::bool_ & is_ascend)1834 py::object TensorIndex::GetItemIndexInfo(const py::object &py_data, const py::object &py_index,
1835                                          const py::bool_ &is_ascend) {
1836   ShapeVector data_shape;
1837   ValuePtr data_value;
1838   if (IsStubTensor(py_data)) {
1839     auto value = GetStubTensorValue(py_data);
1840     MS_EXCEPTION_IF_NULL(value);
1841     auto abs = value->ToAbstract();
1842     MS_EXCEPTION_IF_NULL(abs);
1843     data_shape = dyn_cast<abstract::Shape>(abs->BuildShape())->shape();
1844 
1845     if (EnableView()) {
1846       data_value = value;
1847     }
1848   } else if (py::isinstance<Tensor>(py_data)) {
1849     auto tensor = py_data.cast<TensorPtr>();
1850     MS_EXCEPTION_IF_NULL(tensor);
1851     if (EnableView()) {
1852       data_value = tensor;
1853     }
1854     data_shape = tensor->shape();
1855   } else {
1856     MS_EXCEPTION(TypeError) << "First input of Tensor index must be tensor but got " << py_data;
1857   }
1858 
1859   const auto &simple_index_output = GetItemIndexSimpleIndex(py_index, data_value, data_shape);
1860   if (simple_index_output != py::none()) {
1861     return simple_index_output;
1862   }
1863 
1864   std::vector<int64_t> data_transfer_types;
1865   std::vector<py::object> data_transfer_args;
1866   if (py::isinstance<py::tuple>(py_index) &&
1867       GetItemByTupleWithView(data_value, data_shape, py_index, &data_transfer_types, &data_transfer_args, nullptr)) {
1868     MS_LOG(INFO) << "(View) In branch get item by tuple with view, data_shape: " << data_shape
1869                  << " tensor_indexes: " << py_index;
1870     return py::make_tuple(py::none(), VectorToPyTuple(data_transfer_types), VectorToPyTuple(data_transfer_args));
1871   }
1872   MS_LOG(INFO) << "(Tensor) Get item datashape is: " << data_shape << ", index is: " << py_index;
1873   py::object new_py_index = IsStubTensor(py_index) ? py::cast(ConvertStubTensor(py_index)) : py_index;
1874   TensorIndex::py_index_handle_ = new_py_index;
1875   TensorIndex::is_ascend_ = is_ascend;
1876   TensorIndex::np_module_ = py::module::import("numpy");
1877   TensorIndex::index_op_type_ = IndexOpType::GetItem;
1878   TensorIndex index(new_py_index);
1879   CheckGetItemIndex(index.type());
1880   py::object output = py::none();
1881   switch (index.type()) {
1882     case TensorIndexType::Tensor: {
1883       output = GetItemByTensor(data_shape, index.tensor());
1884       break;
1885     }
1886     case TensorIndexType::List: {
1887       output = GetItemByList(data_shape, index);
1888       break;
1889     }
1890     case TensorIndexType::Tuple: {
1891       output = GetItemByTuple(data_shape, index.ExpandToVector());
1892       break;
1893     }
1894     case TensorIndexType::Boolean: {
1895       output = GetItemByBool(data_value, data_shape, index.boolean());
1896       break;
1897     }
1898     case TensorIndexType::Ellipsis: {
1899       output = py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)),
1900                               py::make_tuple(py::none()));
1901       break;
1902     }
1903     case TensorIndexType::Integer: {
1904       output = GetItemByNumber(data_shape, index.integer());
1905       break;
1906     }
1907     default: {
1908       MS_EXCEPTION(TypeError)
1909         << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
1910            "tuple as index, but got "
1911         << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
1912     }
1913   }
1914   return output;
1915 }
1916 
1917 // ***********************************************get set_item info*******************************************
SetItemByNumber(const ShapeVector & data_shape,const TypePtr & data_type,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type)1918 py::object TensorIndex::SetItemByNumber(const ShapeVector &data_shape, const TypePtr &data_type, bool is_parameter,
1919                                         const TensorIndex &tensor_index, const TensorIndexType &py_value_type) {
1920   // If tensor is small, we use method in IntToTensor for faster
1921   MS_LOG(DEBUG) << "In branch Set item by number, data_shape: " << data_shape << " tensor_indexes: " << tensor_index
1922                 << "value: " << TensorIndex::py_value_handle_;
1923 
1924   std::tuple<int64_t, py::object, ShapeVector> value_transfer =
1925     GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, false);
1926   std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
1927   std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
1928   if (data_shape.empty()) {
1929     MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1930   }
1931   int64_t dim_size = data_shape[0];
1932   int64_t index = tensor_index.integer();
1933   if (index < -dim_size || index >= dim_size) {
1934     MS_EXCEPTION(IndexError) << "Index " << index << " is out of bounds for axis 0 with size " << dim_size;
1935   }
1936   TensorPtr new_index = std::make_shared<Tensor>();
1937   if (!CheckLargeTensor(data_shape)) {
1938     new_index = IntToTensor(index, data_shape);
1939     (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1940     MS_EXCEPTION_IF_NULL(new_index);
1941     ShapeVector value_shape(new_index->shape().begin(), new_index->shape().end() - 1);
1942     value_transfer_args.push_back(VectorToPyTuple<int64_t>(value_shape));
1943   } else {
1944     auto out_i = static_cast<int32_t>(CheckRange(index, dim_size));
1945     new_index = std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({1, 1}), &out_i, int32_bytes_number);
1946     ShapeVector updates_shape = {1};
1947     (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
1948     (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1949     (void)value_transfer_args.emplace_back(VectorToPyTuple(updates_shape));
1950   }
1951   ValueTransferType data_transfer_type =
1952     is_parameter ? ValueTransferType::kScatterNdUpdate : ValueTransferType::kTensorScatterUpdate;
1953   return py::make_tuple(new_index, VectorToPyTuple<int64_t>(value_transfer_types),
1954                         VectorToPyTuple<py::object>(value_transfer_args),
1955                         py::make_tuple(static_cast<int>(data_transfer_type)), py::make_tuple(py::none()));
1956 }
1957 
SetItemByNumberWithView(const ShapeVector & data_shape,const TypePtr & data_type,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type,const ValuePtr & data_value)1958 py::object TensorIndex::SetItemByNumberWithView(const ShapeVector &data_shape, const TypePtr &data_type,
1959                                                 bool is_parameter, const TensorIndex &tensor_index,
1960                                                 const TensorIndexType &py_value_type, const ValuePtr &data_value) {
1961   // If tensor is small, we use method in IntToTensor for faster
1962   MS_LOG(INFO) << "(View) In branch set item by number, data_shape: " << data_shape
1963                << " tensor_indexes: " << tensor_index << "value: " << TensorIndex::py_value_handle_;
1964 
1965   std::tuple<int64_t, py::object, ShapeVector> value_transfer =
1966     GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, true);
1967   std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
1968   std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
1969   if (data_shape.empty()) {
1970     MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1971   }
1972   int64_t dim_size = data_shape[0];
1973   int64_t index = tensor_index.integer();
1974   if (index < -dim_size || index >= dim_size) {
1975     MS_EXCEPTION(IndexError) << "Index " << index << " is out of bounds for axis 0 with size " << dim_size;
1976   }
1977   ShapeVector updates_shape = {1};
1978   (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
1979   std::vector<int64_t> data_transfer_types;
1980   std::vector<py::object> data_transfer_args;
1981   int64_t transformed_number = CheckRange(index, data_shape.at(0));
1982 
1983   std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
1984   std::vector<int64_t> new_data_shape(data_shape.begin() + 1, data_shape.end());
1985   auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1986   slice_op_info->slice_op_name = prim::kPrimSelectView->name();
1987   (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
1988   (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(0));
1989   (void)slice_op_info->data_indexs.emplace_back(0);
1990   (void)slice_op_infos.emplace_back(slice_op_info);
1991   auto slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
1992   if (slice_output != py::none()) {
1993     data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
1994     data_transfer_args.emplace_back(slice_output);
1995     return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
1996                           VectorToPyTuple(data_transfer_args));
1997   }
1998 
1999   (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kSelectView));
2000   (void)data_transfer_args.emplace_back(py::make_tuple(py::int_(transformed_number), py::int_(0)));
2001   (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCopyView));
2002   (void)data_transfer_args.emplace_back(py::none());
2003   return py::make_tuple(py::str("view"), VectorToPyTuple<int64_t>(value_transfer_types),
2004                         VectorToPyTuple<py::object>(value_transfer_args), VectorToPyTuple(data_transfer_types),
2005                         VectorToPyTuple(data_transfer_args));
2006 }
2007 
SetItemByTensor(const ShapeVector & data_shape,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type)2008 py::object TensorIndex::SetItemByTensor(const ShapeVector &data_shape, bool is_parameter,
2009                                         const TensorIndex &tensor_index, const TensorIndexType &py_value_type) {
2010   MS_LOG(DEBUG) << "In branch Set item by tensor, data_shape: " << data_shape << " tensor_indexes: " << tensor_index
2011                 << "value: " << TensorIndex::py_value_handle_;
2012   std::vector<int64_t> value_transfer_types;
2013   std::vector<py::object> value_transfer_args;
2014   const TensorPtr &index = tensor_index.tensor();
2015   int64_t data_dims = SizeToLong(data_shape.size());
2016   MS_EXCEPTION_IF_NULL(index);
2017   bool format_index_tensor = false;
2018   ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
2019   py::array np_index;
2020   if (CheckTypeIsInstance(py_value_type, {TensorIndexType::Float, TensorIndexType::Integer, TensorIndexType::Boolean,
2021                                           TensorIndexType::Tensor})) {
2022     if (!CheckTypeIsInstance<TypeId>(index->data_type(), {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
2023                                                           kNumberTypeInt64, kNumberTypeBool})) {
2024       MS_EXCEPTION(IndexError) << "For tensor set item, the index tensor data type" << index->data_type()
2025                                << " is not supported.";
2026     }
2027     if (index->data_type() == kNumberTypeBool) {
2028       if (CheckScalarValue(TensorIndex::py_value_handle_)) {
2029         np_index = SetItemByTensorByBool(data_shape, index, data_dims, &value_transfer_types, &value_transfer_args,
2030                                          &tensor_update_type);
2031       } else {
2032         return py::make_tuple(index, py::make_tuple(), py::make_tuple(),
2033                               py::make_tuple(static_cast<int>(ValueTransferType::kSetitemByBoolTensor)),
2034                               py::make_tuple(py::none()));
2035       }
2036     } else {
2037       ShapeVector index_shape = index->shape();
2038       np_index = TensorPy::SyncAsNumpy(*index);
2039       if (index_shape.empty()) {
2040         np_index = TensorIndex::np_module_.attr("expand_dims")(np_index, -1);
2041         (void)index_shape.emplace_back(1);
2042       }
2043       ShapeVector updates_shape = index_shape;
2044       (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
2045       if (py_value_type != TensorIndexType::Tensor) {
2046         (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kNumberToTensor));
2047       } else {
2048         (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCast));
2049       }
2050       (void)value_transfer_args.emplace_back(py::none());
2051       (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
2052       (void)value_transfer_args.emplace_back(VectorToPyTuple(updates_shape));
2053       if (data_shape.empty()) {
2054         MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
2055       }
2056       int64_t index_shape_dim = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies<>());
2057       if (index_shape_dim <= 1) {
2058         int64_t first_val = data_shape[0];
2059         np_index = TensorIndex::np_module_.attr("select")(
2060           TensorIndex::np_module_.attr("less")(np_index, 0),
2061           TensorIndex::np_module_.attr("add")(np_index, py::int_(first_val)), np_index);
2062       } else {
2063         format_index_tensor = true;
2064       }
2065       np_index = TensorIndex::np_module_.attr("expand_dims")(np_index, -1);
2066       (void)index_shape.emplace_back(1);
2067       constexpr int64_t min_index_shape_size = 2;
2068       if (index_shape.size() < min_index_shape_size) {
2069         auto np_expand_dims_method = TensorIndex::np_module_.attr("expand_dims");
2070         np_index = np_expand_dims_method(np_index, 0);
2071         (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kExpandDims));
2072         (void)value_transfer_args.emplace_back(py::int_(0));
2073       }
2074       tensor_update_type = is_parameter ? ValueTransferType::kScatterNdUpdate : ValueTransferType::kTensorScatterUpdate;
2075     }
2076   } else if (py_value_type == TensorIndexType::Tuple || py_value_type == TensorIndexType::List) {
2077     (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kHandleSequenceValue));
2078     (void)value_transfer_args.emplace_back(py::make_tuple(py::int_(set_item_by_one_tensor), index));
2079     if (CheckTypeIsInstance<TypeId>(index->data_type(), kIntTypes)) {
2080       np_index = TensorPy::SyncAsNumpy(*index);
2081       np_index = CastToInt(TensorIndex::np_module_.attr("expand_dims")(np_index, -1));
2082       tensor_update_type = ValueTransferType::kTensorScatterUpdate;
2083     } else if (index->data_type() == kNumberTypeBool) {
2084       return py::make_tuple(
2085         index, VectorToPyTuple<int64_t>(value_transfer_types), VectorToPyTuple<py::object>(value_transfer_args),
2086         py::make_tuple(static_cast<int>(ValueTransferType::kSetitemByBoolTensor)), py::make_tuple(py::none()));
2087     } else {
2088       MS_EXCEPTION(TypeError) << "The tensor index must be int or bool type, but got " << tensor_index;
2089     }
2090   }
2091   std::vector<int> tensor_update_types{static_cast<int>(tensor_update_type)};
2092   std::vector<py::object> tensor_update_args{py::none()};
2093   if (format_index_tensor) {
2094     (void)tensor_update_types.insert(tensor_update_types.begin(),
2095                                      static_cast<int>(ValueTransferType::kFormatIndexTensor));
2096     (void)tensor_update_args.insert(tensor_update_args.begin(), py::make_tuple(0, data_shape[0]));
2097   }
2098   return py::make_tuple(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(np_index)),
2099                         VectorToPyTuple<int64_t>(value_transfer_types),
2100                         VectorToPyTuple<py::object>(value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
2101                         VectorToPyTuple<py::object>(tensor_update_args));
2102 }
2103 
SetItemByTuple(const ShapeVector & data_shape,const TypePtr & data_type,const TensorIndex & py_index,const TensorIndexType & py_value_type)2104 py::object TensorIndex::SetItemByTuple(const ShapeVector &data_shape, const TypePtr &data_type,
2105                                        const TensorIndex &py_index, const TensorIndexType &py_value_type) {
2106   MS_LOG(DEBUG) << "In branch Set item by tuple, data_shape: " << data_shape << " tensor_indexes: " << py_index
2107                 << "value: " << TensorIndex::py_value_handle_;
2108   if (!CheckTypeIsInstance<TensorIndexType>(py_value_type,
2109                                             {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
2110                                              TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple})) {
2111     MS_EXCEPTION(TypeError) << "Only support int, float, bool, Tensor, list, tuple as value, but got "
2112                             << TensorIndex::py_value_handle_.get_type();
2113   }
2114 
2115   std::tuple<int64_t, py::object, ShapeVector> value_transfer =
2116     GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, false);
2117   constexpr size_t value_transfer_types_index = 0;
2118   constexpr size_t value_transfer_args_index = 1;
2119   constexpr size_t value_transfer_shapes_index = 2;
2120   std::vector<int64_t> value_transfer_types = {std::get<value_transfer_types_index>(value_transfer)};
2121   std::vector<py::object> value_transfer_args = {std::get<value_transfer_args_index>(value_transfer)};
2122   ShapeVector value_transfer_shape = {std::get<value_transfer_shapes_index>(value_transfer)};
2123 
2124   if (CheckTypeIsInstance<TensorIndexType>(
2125         py_value_type, {TensorIndexType::Boolean, TensorIndexType::Float, TensorIndexType::Integer})) {
2126     TensorIndex index = TensorIndex::UnpackTuple(py_index);
2127     std::vector<TensorIndex> index_list = index.ExpandToVector();
2128     return SetitemByTupleWithTensor(data_shape, index_list, value_transfer_shape, &value_transfer_types,
2129                                     &value_transfer_args);
2130   }
2131   std::vector<TensorIndex> index_list = py_index.ExpandToVector();
2132   return SetitemByTupleWithTensor(data_shape, index_list, value_transfer_shape, &value_transfer_types,
2133                                   &value_transfer_args);
2134 }
2135 
SetItemBySlice(const ShapeVector & data_shape,const TypePtr & data_type,const TensorIndex & tensor_index,const TensorIndexType & py_value_type,const ValuePtr & data_value)2136 py::object TensorIndex::SetItemBySlice(const ShapeVector &data_shape, const TypePtr &data_type,
2137                                        const TensorIndex &tensor_index, const TensorIndexType &py_value_type,
2138                                        const ValuePtr &data_value) {
2139   MS_LOG(INFO) << "(View) In branch set item by slice, data_shape: " << data_shape
2140                << " tensor_indexes: " << tensor_index << "value: " << TensorIndex::py_value_handle_;
2141   if (!CheckTypeIsInstance<TensorIndexType>(py_value_type,
2142                                             {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
2143                                              TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple})) {
2144     MS_EXCEPTION(TypeError) << "Only support int, float, bool, Tensor, list, tuple as value, but got "
2145                             << TensorIndex::py_value_handle_.get_type();
2146   }
2147   Slice slice_info = Slice(tensor_index.slice(), data_shape[0]);
2148   std::tuple<int64_t, py::object, ShapeVector> value_transfer =
2149     GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, slice_info.step() >= 0);
2150   std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
2151   std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
2152   return SetitemBySliceWithTensor(data_shape, tensor_index, &value_transfer_types, &value_transfer_args, data_value,
2153                                   data_type);
2154 }
2155 
SetItemIndexInfo(const py::object & py_data,const py::object & py_index,const py::object & py_value,const py::bool_ & is_ascend)2156 py::object TensorIndex::SetItemIndexInfo(const py::object &py_data, const py::object &py_index,
2157                                          const py::object &py_value, const py::bool_ &is_ascend) {
2158   if (!py::isinstance<Tensor>(py_data) && !IsStubTensor(py_data)) {
2159     MS_EXCEPTION(TypeError) << "First input of Tensor index must be tensor but got " << py_data;
2160   }
2161   ShapeVector data_shape;
2162   TypePtr data_type;
2163   bool is_parameter = false;
2164   ValuePtr data_value;
2165   if (IsStubTensor(py_data)) {  // PackTensor have not real Tensor.
2166     auto value = GetStubTensorValue(py_data);
2167     MS_EXCEPTION_IF_NULL(value);
2168     auto abs = value->ToAbstract();
2169     MS_EXCEPTION_IF_NULL(abs);
2170     data_shape = dyn_cast<abstract::Shape>(abs->BuildShape())->shape();
2171     data_type = abs->BuildType();
2172     MS_EXCEPTION_IF_NULL(data_type);
2173     if (EnableView()) {
2174       data_value = value;
2175     }
2176   } else {
2177     TensorPtr data = py_data.cast<TensorPtr>();
2178     MS_EXCEPTION_IF_NULL(data);
2179     if (EnableView(true)) {
2180       data_value = data;
2181     }
2182     data_shape = data->shape();
2183     data_type = data->Dtype();
2184     is_parameter = data->is_parameter();
2185   }
2186   TensorIndex::py_value_handle_ = py_value;
2187   TensorIndex::np_module_ = py::module::import("numpy");
2188   TensorIndex::py_index_handle_ = py_index;
2189   TensorIndex::is_ascend_ = is_ascend;
2190   TensorIndex::index_op_type_ = IndexOpType::SetItem;
2191   const TensorIndexType value_type = IsStubTensor(py_value) ? TensorIndexType::Tensor : TensorIndex(py_value).type();
2192   bool valid = CheckTypeIsInstance<TensorIndexType>(
2193     value_type, {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean, TensorIndexType::Tensor,
2194                  TensorIndexType::List, TensorIndexType::Tuple});
2195   if (!valid) {
2196     MS_EXCEPTION(TypeError) << "only support numbers, Tensor, tuple, list as value, but got "
2197                             << TensorIndex::py_value_handle_ << " with type "
2198                             << TensorIndex::py_value_handle_.get_type();
2199   }
2200   if (py::isinstance<py::int_>(py_index) && !py::isinstance<py::bool_>(py_index) && data_value != nullptr) {
2201     return SetItemByNumberWithView(data_shape, data_type, is_parameter, TensorIndex(py_index), value_type, data_value);
2202   }
2203   if (py::isinstance<py::slice>(py_index)) {
2204     return TensorIndex::SetItemBySlice(data_shape, data_type, TensorIndex(py_index), value_type, data_value);
2205   }
2206   if (data_value != nullptr && (py::isinstance<py::none>(py_index) || py::isinstance<py::ellipsis>(py_index))) {
2207     auto output = py::make_tuple(
2208       py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
2209       py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByEllipsis)), py::make_tuple(py::none()));
2210     return output;
2211   }
2212   std::vector<int64_t> data_transfer_types;
2213   std::vector<py::object> data_transfer_args;
2214   if (py::isinstance<py::tuple>(py_index) &&
2215       GetItemByTupleWithView(data_value, data_shape, py_index, &data_transfer_types, &data_transfer_args, data_type)) {
2216     MS_LOG(INFO) << "(View) In branch set item by tuple with view, data_shape: " << data_shape
2217                  << " tensor_indexes: " << py_index;
2218     return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
2219                           VectorToPyTuple(data_transfer_args));
2220   }
2221   MS_LOG(INFO) << "(Tensor) Set item data shape is: " << data_shape << ", index is: " << py_index
2222                << ", value is: " << py_value;
2223   TensorIndex index = TensorIndex(py_index);
2224 
2225   CheckSetItemIndex(index.type(), value_type);
2226   if (index.IsList()) {
2227     if (data_shape.empty()) {
2228       MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
2229     }
2230     index = TensorIndex::FormatList(index, data_shape[0]);
2231   }
2232 
2233   return SetItemIndexByIndexType(index, py_index, data_shape, data_type, value_type, is_parameter);
2234 }
2235 
// Dispatches a set-item request on the type of the parsed index.
//
// Parameters:
//   index        - parsed index whose type() selects the branch.
//   py_index     - original Python index object; placed in the result for a Boolean index.
//   data_shape   - shape of the tensor being assigned into.
//   data_type    - dtype of the tensor being assigned into.
//   value_type   - type category of the value being assigned.
//   is_parameter - forwarded to the Integer and Tensor handlers.
//
// Returns:
//   Integer / Tensor / Tuple: the result of SetItemByNumber / SetItemByTensor / SetItemByTuple.
//   Ellipsis / None / Boolean: a 5-element tuple built here.
//
// Raises:
//   TypeError for any index type not listed in the switch below.
py::object TensorIndex::SetItemIndexByIndexType(const TensorIndex &index, const py::object &py_index,
                                                const ShapeVector &data_shape, const TypePtr &data_type,
                                                const TensorIndexType &value_type, bool is_parameter) {
  // Placeholder result; every case that does not throw overwrites it.
  py::object output =
    py::make_tuple(py::none(), py::none(), py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kUnknown)),
                   py::make_tuple(py::none()));
  switch (index.type()) {
    case TensorIndexType::Integer: {
      output = SetItemByNumber(data_shape, data_type, is_parameter, index, value_type);
      break;
    }
    case TensorIndexType::Tensor: {
      output = SetItemByTensor(data_shape, is_parameter, index, value_type);
      break;
    }
    case TensorIndexType::Tuple: {
      output = SetItemByTuple(data_shape, data_type, index, value_type);
      break;
    }
    case TensorIndexType::Ellipsis:
    case TensorIndexType::None: {
      // No index object is returned; the second transfer slot is tagged kSetItemByEllipsis.
      output = py::make_tuple(
        py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
        py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByEllipsis)), py::make_tuple(py::none()));
      break;
    }
    case TensorIndexType::Boolean: {
      // The original Python bool is handed back as the index; the second transfer slot is tagged kSetItemByBool.
      output = py::make_tuple(
        py_index, py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
        py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByBool)), py::make_tuple(py::none()));
      break;
    }
    default: {
      // Leading space before "with type" keeps the printed index separated from the rest of the message.
      MS_EXCEPTION(TypeError)
        << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
           "tuple as index, but got "
        << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
    }
  }

  return output;
}
2278 
2279 }  // namespace mindspore::tensor
2280