1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "pybind_api/ir/tensor_index_py.h"
#include <pybind11/stl.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/pytypes.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"
#include "pipeline/pynative/pynative_execute.h"
#include "mindspore/core/ops/array_ops.h"
31
32 namespace mindspore::tensor {
using tensor::TensorPy;
// Class-level state shared across one python indexing call: the raw python
// index/value handles are stashed here so later error messages can echo them.
py::handle TensorIndex::py_index_handle_ = py::none();
py::handle TensorIndex::py_value_handle_ = py::none();
// NOTE(review): presumably set to true when running on an Ascend backend —
// not assigned anywhere in this file; confirm against the setter's caller.
bool TensorIndex::is_ascend_ = false;
IndexOpType TensorIndex::index_op_type_ = IndexOpType::GetItem;
// Cached handle to the python `numpy` module used by the helpers below.
py::module TensorIndex::np_module_ = py::module();
// Integer dtypes accepted for tensor indices (see TensorGetitemByTupleParseTensorIndex).
static const std::vector<TypeId> kIntTypes{kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
40 // ***********************************************utils*******************************************
// Pretty-print a single TensorIndex according to its payload type.
// Used when composing IndexError/TypeError messages elsewhere in this file.
std::ostream &operator<<(std::ostream &stream, const TensorIndex &tensor_index) {
  TensorIndexType tensor_index_type = tensor_index.type();
  switch (tensor_index_type) {
    case TensorIndexType::None: {
      stream << "None";
      break;
    }
    case TensorIndexType::Integer: {
      stream << tensor_index.integer();
      break;
    }
    case TensorIndexType::Ellipsis: {
      stream << "...";
      break;
    }
    case TensorIndexType::Boolean: {
      // boolalpha: print "true"/"false" instead of 1/0.
      stream << std::boolalpha << tensor_index.boolean();
      break;
    }
    case TensorIndexType::Slice: {
      stream << tensor_index.slice();
      break;
    }
    case TensorIndexType::Tensor: {
      MS_EXCEPTION_IF_NULL(tensor_index.tensor());
      stream << tensor_index.tensor()->ToString();
      break;
    }
    case TensorIndexType::List: {
      stream << tensor_index.list();
      break;
    }
    case TensorIndexType::Tuple: {
      stream << tensor_index.tuple();
      break;
    }
    case TensorIndexType::Array: {
      stream << tensor_index.array();
      break;
    }
    case TensorIndexType::Float: {
      stream << tensor_index.floating_point();
      break;
    }
  }
  return stream;
}
88
operator <<(std::ostream & stream,const std::vector<TensorIndex> & tensor_indices)89 std::ostream &operator<<(std::ostream &stream, const std::vector<TensorIndex> &tensor_indices) {
90 stream << "(";
91 for (size_t i = 0; i < tensor_indices.size(); i++) {
92 stream << tensor_indices[i];
93 if (i < tensor_indices.size() - 1) {
94 stream << ", ";
95 }
96 }
97 stream << ")";
98 return stream;
99 }
100
CheckGetItemIndex(const TensorIndexType & index_data_type)101 void TensorIndex::CheckGetItemIndex(const TensorIndexType &index_data_type) {
102 bool valid = CheckTypeIsInstance<TensorIndexType>(
103 index_data_type,
104 {TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Boolean, TensorIndexType::Slice,
105 TensorIndexType::Integer, TensorIndexType::Tuple, TensorIndexType::Ellipsis, TensorIndexType::None});
106 if (!valid) {
107 MS_EXCEPTION(IndexError)
108 << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
109 "tuple as index, but got "
110 << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
111 }
112 }
113
CheckSetItemIndex(const TensorIndexType & index_data_type,const TensorIndexType & value_data_type)114 void TensorIndex::CheckSetItemIndex(const TensorIndexType &index_data_type, const TensorIndexType &value_data_type) {
115 CheckGetItemIndex(index_data_type);
116 bool valid = CheckTypeIsInstance<TensorIndexType>(
117 value_data_type, {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
118 TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple});
119 if (!valid) {
120 MS_EXCEPTION(TypeError) << "only support numbers, Tensor, tuple, list as value, but got "
121 << TensorIndex::py_value_handle_ << " with type "
122 << TensorIndex::py_value_handle_.get_type();
123 }
124 }
125
// Compute the numpy-style broadcast of two shapes.
// Trailing dimensions are compared pairwise, right-aligned: equal dims (or a
// dim of 1 on either side) broadcast; any other mismatch raises ValueError.
// The unmatched leading dims of the longer shape are prepended unchanged.
ShapeVector TensorIndex::BroadCastShape(const ShapeVector &x_shape, const ShapeVector &y_shape) {
  if (x_shape == y_shape) {
    return x_shape;
  }
  const size_t x_len = x_shape.size();
  const size_t y_len = y_shape.size();
  const size_t min_length = std::min(x_len, y_len);
  ShapeVector broadcast_shape_back;

  // Walk the overlapping (trailing) dimensions, right-aligned.
  for (size_t i = 0; i < min_length; i++) {
    size_t x_shape_index = x_len - min_length + i;
    size_t y_shape_index = y_len - min_length + i;
    if (x_shape[x_shape_index] == 1) {
      (void)broadcast_shape_back.emplace_back(y_shape[y_shape_index]);
    } else if (y_shape[y_shape_index] == 1 || x_shape[x_shape_index] == y_shape[y_shape_index]) {
      (void)broadcast_shape_back.emplace_back(x_shape[x_shape_index]);
    } else {
      // Mismatch: tailor the error text to the current indexing operation.
      string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
      MS_EXCEPTION(ValueError) << "For '" << index_op_type
                               << "', x.shape and y.shape need to broadcast. The value of x.shape["
                               << std::to_string(x_shape_index) << "] or y.shape[" << std::to_string(y_shape_index)
                               << "] must be 1 or -1 when they are not the same, but got x.shape = " << x_shape
                               << " and y.shape = " << y_shape;
    }
  }
  // Prepend the surplus leading dims of the longer input, then the broadcast tail.
  ShapeVector broadcast_shape_front;
  if (min_length == x_len) {
    (void)broadcast_shape_front.insert(
      broadcast_shape_front.end(), y_shape.begin(),
      y_shape.begin() + static_cast<int64_t>(y_len) - static_cast<int64_t>(min_length));
  } else {
    (void)broadcast_shape_front.insert(
      broadcast_shape_front.end(), x_shape.begin(),
      x_shape.begin() + static_cast<int64_t>(x_len) - static_cast<int64_t>(min_length));
  }
  (void)broadcast_shape_front.insert(broadcast_shape_front.end(), broadcast_shape_back.begin(),
                                     broadcast_shape_back.end());
  return broadcast_shape_front;
}
165
// Convert a python list/tuple index for one dimension into a tensor index.
// An all-bool sequence must have exactly dim_size elements and is treated as a
// mask: it becomes a tensor of the positions that are True. An empty sequence
// (or an all-False mask) yields the sentinel index `False`. Any other sequence
// is converted through numpy; a ragged or non-numeric result (dtype "object")
// is rejected.
template <typename T>
TensorIndex TensorIndex::SequenceToTensor(const T &sequence, int64_t dim_size) {
  if (sequence.empty()) {
    return TensorIndex(py::bool_(false));
  }
  if (std::all_of(sequence.begin(), sequence.end(), [](auto &x) { return py::isinstance<py::bool_>(x); })) {
    int64_t seq_size = SizeToLong(sequence.size());
    if (seq_size != dim_size) {
      MS_EXCEPTION(IndexError) << "dimension is " << dim_size << " but corresponding boolean dimension is " << seq_size;
    }
    // Collect the positions where the mask is True.
    py::list new_range_dim_size;
    for (size_t i = 0; i < sequence.size(); i++) {
      if (py::cast<bool>(sequence[i]) == true) {
        new_range_dim_size.append(py::int_(i));
      }
    }
    if (new_range_dim_size.empty()) {
      return TensorIndex(py::bool_(false));
    }
    return TensorIndex(TensorPy::MakeTensor(MakeNdArray(new_range_dim_size, dim_size)));
  }
  py::array output = MakeNdArray(sequence, dim_size);
  // numpy falls back to dtype "object" when elements are ragged or non-numeric.
  if (output.dtype() == pybind11::dtype("object")) {
    MS_LOG(EXCEPTION) << "Sequence as indices must have the same size across all dimensions and elements must be "
                         "integer (or boolean) type";
  }
  return TensorIndex(TensorPy::MakeTensor(output));
}
194
// Recursively collapse one-element tuples/lists to their inner element,
// e.g. ((5,),) -> 5. Anything else is returned unchanged.
py::object TensorIndex::Unpack(const py::object &x) {
  if (py::isinstance<py::tuple>(x)) {
    auto seq = x.cast<py::tuple>();
    if (seq.size() == 1) {
      return Unpack(seq[0]);
    }
    return x;
  }
  if (py::isinstance<py::list>(x)) {
    auto seq = x.cast<py::list>();
    if (seq.size() == 1) {
      return Unpack(seq[0]);
    }
    return x;
  }
  return x;
}
210
211 template <typename T>
UnpackTuple(const T & sequence)212 TensorIndex TensorIndex::UnpackTuple(const T &sequence) {
213 py::tuple res(sequence.size());
214 for (size_t i = 0; i < sequence.size(); i++) {
215 if (py::isinstance<py::list>(sequence[i]) || py::isinstance<py::tuple>(sequence[i])) {
216 res[i] = Unpack(sequence[i]);
217 } else {
218 res[i] = sequence[i];
219 }
220 }
221 return TensorIndex(res);
222 }
223
// Recursively apply CheckRange to every scalar inside a (possibly nested)
// list/tuple; sequence elements are rewritten in place and the (cast) sequence
// is returned, while a scalar returns its checked value.
// NOTE(review): CheckRange appears to validate/normalize an index against
// dim_size — confirm its exact contract in tensor_index_py.h.
py::object TensorIndex::DeepList(const py::object &array_like, int64_t dim_size) {
  py::object new_array_like = CheckRange(array_like, dim_size);
  if (py::isinstance<py::list>(array_like) || py::isinstance<py::tuple>(array_like)) {
    auto list_array_like = array_like.cast<py::list>();
    for (size_t i = 0; i < list_array_like.size(); i++) {
      list_array_like[i] = DeepList(list_array_like[i], dim_size);
    }
    // For sequences, the recursion result is returned (new_array_like unused).
    return list_array_like;
  }
  return new_array_like;
}
235
// Recursively replace every Tensor (or stub tensor) found inside a python
// object with its numpy representation so the whole structure can be fed to
// numpy. Non-tensor, non-list objects are returned unchanged.
py::object TensorIndex::DeepTensorToNdArray(const py::object &array_like) {
  if (py::isinstance<tensor::Tensor>(array_like) || IsStubTensor(array_like)) {
    auto tensor_index = IsStubTensor(array_like) ? ConvertStubTensor(array_like) : py::cast<TensorPtr>(array_like);
    MS_EXCEPTION_IF_NULL(tensor_index);
    return TensorPy::AsNumpy(*tensor_index);
  }
  if (py::isinstance<py::list>(array_like)) {
    // Rewrite list elements in place, descending into nested lists.
    auto new_array_like_vector = array_like.cast<py::list>();
    for (size_t i = 0; i < new_array_like_vector.size(); i++) {
      new_array_like_vector[i] = DeepTensorToNdArray(new_array_like_vector[i]);
    }
    return new_array_like_vector;
  }
  return array_like;
}
251
// Convert a python scalar or (nested) sequence index into a numpy array,
// range-checking every element against dim_size along the way.
// Raises TypeError for anything that is not int/float/bool/list/tuple.
py::array TensorIndex::MakeNdArray(const py::object &a, int64_t dim_size) {
  if (!py::isinstance<py::list>(a) && !py::isinstance<py::tuple>(a) && !py::isinstance<py::int_>(a) &&
      !py::isinstance<py::float_>(a) && !py::isinstance<py::bool_>(a)) {
    MS_EXCEPTION(TypeError) << "Input data must be `int`, `float`, `bool`, `list` or `tuple` but got " << a.get_type();
  }
  py::object new_array = CheckRange(a, dim_size);
  if (py::isinstance<py::list>(new_array) || py::isinstance<py::tuple>(new_array)) {
    // Sequences: range-check nested scalars, then replace embedded tensors
    // with numpy arrays before the implicit py::object -> py::array conversion.
    new_array = DeepList(new_array, dim_size);
    new_array = DeepTensorToNdArray(new_array);
  }
  return new_array;
}
264
265 namespace Convert {
ConvertTypeToString(const TensorIndex & index)266 string ConvertTypeToString(const TensorIndex &index) {
267 if (index.IsNone())
268 return "None";
269 else if (index.IsEllipsis())
270 return "Ellipsis";
271 else if (index.IsInteger())
272 return "Integer";
273 else if (index.IsBoolean())
274 return "Boolean";
275 else if (index.IsSlice())
276 return "Slice";
277 else if (index.IsTensor())
278 return "Tensor";
279 else if (index.IsList())
280 return "List";
281 else if (index.IsTuple())
282 return "Tuple";
283 else if (index.IsArray())
284 return "Array";
285 else if (index.IsFloat())
286 return "Float";
287 return "Unknown";
288 }
289 } // namespace Convert
290
// Check if the tuple index len is longer than the data's dims and transform
// ellipsis in the indices to several slices: the single allowed '...' expands
// into as many full slices as there are data dims not consumed by the other
// index elements (None/bool consume none).
std::vector<TensorIndex> TensorIndex::TransformEllipsisToSlice(const ShapeVector &data_shape,
                                                               const std::vector<TensorIndex> &indices) {
  int64_t ellipsis_occupy_dims = SizeToLong(data_shape.size());
  int64_t ellipsis_positions = 0;
  int64_t ellipsis_cnt = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    bool valid = (CheckTypeIsInstance<TensorIndexType>(
      indices[i].type(),
      {TensorIndexType::List, TensorIndexType::Ellipsis, TensorIndexType::Tuple, TensorIndexType::None,
       TensorIndexType::Integer, TensorIndexType::Tensor, TensorIndexType::Slice, TensorIndexType::Boolean}));
    if (!valid) {
      MS_EXCEPTION(TypeError) << "For tuple index, the types only support 'Slice', 'Ellipsis', 'None', 'Tensor', "
                                 "'int', 'List', 'Tuple', 'bool', but got type '"
                              << Convert::ConvertTypeToString(indices[i]) << "', value: " << indices[i];
    }
    if (indices[i].IsSlice() || indices[i].IsInteger() || indices[i].IsTensor() || indices[i].IsSequence()) {
      // These index kinds each consume one data dimension.
      ellipsis_occupy_dims -= 1;
    } else if (indices[i].IsEllipsis()) {
      if (ellipsis_cnt >= 1) {
        MS_EXCEPTION(IndexError) << "An index can only have a single ellipsis('...')";
      }
      ellipsis_cnt += 1;
      ellipsis_positions = static_cast<int64_t>(i);
    }
  }
  if (ellipsis_occupy_dims < 0) {
    MS_EXCEPTION(IndexError) << "Tuple index " << indices << " out rang of tensor shape " << data_shape;
  }

  if (ellipsis_cnt == 0) {
    return indices;
  }

  // Replace the ellipsis element with `ellipsis_occupy_dims` full slices:
  // erase() returns the iterator at the removed position, insert() fills there.
  std::vector<TensorIndex> empty_slice(ellipsis_occupy_dims, TensorIndex(Slice()));
  std::vector<TensorIndex> new_indices(indices.begin(), indices.end());
  MS_EXCEPTION_IF_CHECK_FAIL(ellipsis_positions <= SizeToLong(new_indices.size()), "Index out of vector size.");
  (void)new_indices.insert(new_indices.erase(new_indices.begin() + ellipsis_positions), empty_slice.begin(),
                           empty_slice.end());
  return new_indices;
}
333
// From the positions and shapes of the tensor indices in a mixed tuple index,
// plus the dims contributed by slices, compute:
//   - the broadcast shape of all tensor indices,
//   - the shape each index tensor must be reshaped to,
//   - the final output shape,
//   - the "fancy position" where the broadcast block is inserted.
// When py_fancy_position is None, the fancy position is the first tensor
// position if the tensor indices are contiguous in the tuple, else 0.
std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> TensorIndex::GenerateIndexInfoFromTupleOfMixedTensors(
  const std::vector<int64_t> &tensor_positions, const std::vector<ShapeVector> &tensor_indexes_shapes,
  const ShapeVector &slice_shapes, const TensorIndex &py_fancy_position) {
  // "Continuous" means the tensor indices occupy consecutive tuple positions.
  bool tensor_index_continue_tag = true;
  if (tensor_positions.empty()) {
    tensor_index_continue_tag = false;
  }
  for (size_t i = 1; i < tensor_positions.size(); i++) {
    if (tensor_positions[i] != tensor_positions[i - 1] + 1) {
      tensor_index_continue_tag = false;
      break;
    }
  }
  int64_t fancy_position = 0;
  if (py_fancy_position.IsNone()) {
    fancy_position = tensor_index_continue_tag ? tensor_positions[0] : 0;
  } else {
    fancy_position = py_fancy_position.integer();
  }

  ShapeVector broadcast_shape = BroadCastShape(tensor_indexes_shapes);

  // Clamp so the insertion point never exceeds the slice-dim count.
  fancy_position = std::min(fancy_position, SizeToLong(slice_shapes.size()));
  ShapeVector final_shape = slice_shapes;
  (void)final_shape.insert(final_shape.begin() + fancy_position, broadcast_shape.begin(), broadcast_shape.end());

  // Index tensors are reshaped to all-ones over the slice dims with the
  // broadcast block spliced in at the fancy position.
  ShapeVector index_tensor_new_shape(slice_shapes.size(), 1);
  fancy_position = std::min(fancy_position, SizeToLong(index_tensor_new_shape.size()));

  (void)index_tensor_new_shape.insert(index_tensor_new_shape.begin() + fancy_position, broadcast_shape.begin(),
                                      broadcast_shape.end());

  return std::make_tuple(broadcast_shape, index_tensor_new_shape, final_shape, fancy_position);
}
368
// Expand a slice index into an explicit numpy array of coordinates over
// `shape`. Returns the sentinel index `False` when the slice selects nothing.
TensorIndex TensorIndex::SliceToArray(const TensorIndex &tensor_index, const ShapeVector &shape) {
  MS_EXCEPTION_IF_CHECK_FAIL(!shape.empty(), "DataShape of Tensor can not be empty when sed item");
  Slice slice_info = Slice(tensor_index.slice(), shape[0]);
  int64_t start = slice_info.start();
  int64_t stop = slice_info.stop();
  int64_t step = slice_info.step();
  // Empty slice: start has not progressed toward stop along the step direction.
  if ((start - stop) * step >= 0) {
    return TensorIndex(py::bool_(false));
  }
  int64_t n_dim = SizeToLong(shape.size());
  py::tuple grids(n_dim);
  // First axis follows the slice; the remaining axes take their full range.
  grids[0] = TensorIndex::np_module_.attr("arange")(py::int_(start), py::int_(stop), py::int_(step));
  for (size_t i = 1; i < shape.size(); i++) {
    grids[i] = TensorIndex::np_module_.attr("arange")(0, py::int_(shape[i]), 1, TensorIndex::np_module_.attr("int32"));
  }

  // np.ix_ builds an open mesh, broadcast_arrays densifies it, and
  // stack(..., -1) packs the per-axis coordinates into the trailing dim.
  py::object mesh = TensorIndex::np_module_.attr("ix_")(*grids);
  py::tuple broadcast_mesh = TensorIndex::np_module_.attr("broadcast_arrays")(*mesh);
  return TensorIndex(TensorIndex::np_module_.attr("stack")(broadcast_mesh, -1));
}
389
// Align one slice-derived index tensor with the final output shape: compute
// its target shape from the slice dims/broadcast rank/fancy position, cast it
// to int32, reshape, and broadcast to final_shape.
TensorIndex TensorIndex::SliceToArray(const TensorPtr &index, const ShapeVector &final_shape, size_t slice_cnt,
                                      const ShapeVector &broadcast_shape, const ShapeVector &slice_shape,
                                      int64_t fancy_position) {
  ShapeVector shape = ComputeSliceShape(slice_shape, broadcast_shape.size(), slice_cnt, fancy_position);

  py::object array = TensorPy::SyncAsNumpy(*index);
  // Cast to int32 before reshaping/broadcasting with numpy.
  array = TensorIndex::np_module_.attr("ndarray").attr("astype")(array, TensorIndex::np_module_.attr("int32"));
  array = TensorIndex::np_module_.attr("reshape")(array, py::cast(shape));
  array = BroadCastTo(final_shape, array);
  return TensorIndex(array);
}
401
// Broadcast a numpy-compatible object to the given shape via numpy.broadcast_to.
py::object TensorIndex::BroadCastTo(const ShapeVector &broadcast_shape, const py::object &item) {
  const py::object target_shape = py::cast(broadcast_shape);
  return TensorIndex::np_module_.attr("broadcast_to")(item, target_shape);
}
405
// Align a tensor index with the final output shape: cast to int32, broadcast
// to broadcast_shape, reshape to new_shape, then broadcast to final_shape.
TensorIndex TensorIndex::BroadCastTensor(const ShapeVector &broadcast_shape, const ShapeVector &final_shape,
                                         const ShapeVector &new_shape, const TensorPtr &item) {
  py::array py_item = TensorPy::SyncAsNumpy(*item);
  py_item = TensorIndex::np_module_.attr("ndarray").attr("astype")(py_item, TensorIndex::np_module_.attr("int32"));
  py_item = BroadCastTo(broadcast_shape, py_item);
  return TensorIndex(BroadCastTo(final_shape, TensorIndex::np_module_.attr("reshape")(py_item, py::cast(new_shape))));
}
413
// Decide how a setitem value must be transformed before assignment.
// Returns {transfer type (ValueTransferType cast to int), transfer argument,
// value shape}:
//   - Tensor values pass through (kByPass); their shape is recorded unless
//     this is a view op, in which case the shape is left empty.
//   - Scalar values (int/float/bool) map to kNumberToTensor.
//   - Lists/tuples map to kHandleSequenceValue; the shape is approximated as
//     {len(sequence)} + shape(first element), and the op type plus the python
//     index are passed along as the transfer argument.
std::tuple<int64_t, py::object, ShapeVector> TensorIndex::GetValueTransferType(const TensorIndexType &py_value_type,
                                                                               int64_t op_type,
                                                                               const TypePtr &data_type, bool is_view) {
  ValueTransferType value_transfer_type = ValueTransferType::kByPass;
  py::object value_transfer_arg = py::none();
  ShapeVector value_shape = {};
  if (py_value_type == TensorIndexType::Tensor) {
    if (is_view) {
      return std::make_tuple(static_cast<int>(value_transfer_type), value_transfer_arg, value_shape);
    }
    value_transfer_arg = py::none();
    if (IsStubTensor(TensorIndex::py_value_handle_)) {
      value_shape = GetStubTensorInfo(TensorIndex::py_value_handle_).first;
    } else {
      auto value_ptr = TensorIndex::py_value_handle_.cast<TensorPtr>();
      MS_EXCEPTION_IF_NULL(value_ptr);
      value_shape = value_ptr->shape();
    }
  } else if (CheckTypeIsInstance(py_value_type,
                                 {TensorIndexType::Float, TensorIndexType::Integer, TensorIndexType::Boolean})) {
    value_transfer_type = ValueTransferType::kNumberToTensor;
    value_transfer_arg = py::none();
  } else if (py_value_type == TensorIndexType::List || py_value_type == TensorIndexType::Tuple) {
    value_transfer_type = ValueTransferType::kHandleSequenceValue;
    auto py_value_list = TensorIndex::py_value_handle_.cast<py::list>();
    if (!py_value_list.empty()) {
      // Shape heuristic: sequence length followed by the first element's shape.
      (void)value_shape.emplace_back(SizeToLong(py_value_list.size()));
      const py::object &first_py_ele = py_value_list[0];
      TensorPtr ele;
      if (py::isinstance<Tensor>(first_py_ele) || IsStubTensor(first_py_ele)) {
        ele = IsStubTensor(first_py_ele) ? ConvertStubTensor(first_py_ele) : py::cast<TensorPtr>(first_py_ele);
      } else {
        ele = TensorPy::MakeTensor(py_value_list[0], data_type);
      }
      MS_EXCEPTION_IF_NULL(ele);
      (void)value_shape.insert(value_shape.end(), ele->shape().begin(), ele->shape().end());
    }
    value_transfer_arg = py::make_tuple(py::int_(op_type), TensorIndex::py_index_handle_);
  }
  return std::make_tuple(static_cast<int>(value_transfer_type), value_transfer_arg, value_shape);
}
455
// Cast a numpy array to int32 through numpy's ndarray.astype.
static py::array CastToInt(const py::array &input) {
  const auto &np = TensorIndex::np_module_;
  return np.attr("ndarray").attr("astype")(input, np.attr("int32"));
}
459
CheckLargeTensor(const ShapeVector & data_shape)460 static bool CheckLargeTensor(const ShapeVector &data_shape) {
461 constexpr int64_t max_dim = 1024 * 32;
462 int64_t data_shape_dim = std::accumulate(data_shape.begin(), data_shape.end(), 1, std::multiplies<>());
463 return data_shape_dim > max_dim;
464 }
465
466 // ***********************************************for get_item*******************************************
GenerateNonZeroIndex(const ShapeVector & data_shape,const TensorPtr & tensor_index,bool check_align)467 py::tuple TensorIndex::GenerateNonZeroIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index,
468 bool check_align) {
469 if (!check_align) {
470 py::array index_array = TensorPy::SyncAsNumpy(*tensor_index);
471 return TensorIndex::np_module_.attr("nonzero")(index_array);
472 }
473 const int64_t data_dim = SizeToLong(data_shape.size());
474 const int64_t index_dims = tensor_index->DataDim();
475 if (data_dim < index_dims) {
476 MS_EXCEPTION(IndexError) << "The dim of index cannot be greater than indexed data, but got dim of index:"
477 << index_dims << ", dim of data:" << data_dim;
478 }
479 for (size_t i = 0; i < static_cast<size_t>(index_dims); i++) {
480 if (data_shape[i] != tensor_index->shape()[i]) {
481 MS_EXCEPTION(ValueError) << "The shape of index " << tensor_index->shape()
482 << "does not match the shape of the indexed data " << data_shape << " at dim index" << i;
483 }
484 }
485 py::array index_array = TensorPy::SyncAsNumpy(*tensor_index);
486 return TensorIndex::np_module_.attr("nonzero")(index_array);
487 }
488
// Run nonzero on a bool index tensor and wrap each per-dimension index array
// in a Tensor. Returns an empty vector when the mask selects nothing.
std::vector<TensorPtr> TensorIndex::GenerateNonZeroIndexTensorList(const ShapeVector &data_shape,
                                                                   const TensorPtr &tensor_index, bool check_align) {
  py::tuple nonzero_indices = GenerateNonZeroIndex(data_shape, tensor_index, check_align);
  MS_EXCEPTION_IF_CHECK_FAIL(!nonzero_indices.empty(), "Output size of nonzero should not be empty");
  // All nonzero outputs have the same length; probing the first suffices.
  int64_t nonzero_indices_nums = SizeToLong(len(py::array(nonzero_indices[0])));
  if (nonzero_indices_nums == 0) {
    return {};
  }
  std::vector<TensorPtr> nonzero_indices_tensor_list;
  (void)std::transform(nonzero_indices.begin(), nonzero_indices.end(), std::back_inserter(nonzero_indices_tensor_list),
                       [](const py::handle &nonzero_index) {
                         return TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(nonzero_index));
                       });
  return nonzero_indices_tensor_list;
}
504
// Parse one tensor element of a tuple index.
// Int tensors are appended as-is; bool tensors are converted through nonzero
// into one int tensor per mask dimension. Returns false only when a bool mask
// selects nothing, so the caller can produce an empty result.
bool TensorIndex::TensorGetitemByTupleParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &tensor_index,
                                                       std::vector<TensorPtr> *tuple_index_new,
                                                       std::vector<TensorPtr> *tensor_indexes,
                                                       std::vector<int64_t> *tensor_positions, bool check_align) {
  // parse index of tensor type
  MS_EXCEPTION_IF_NULL(tensor_index);
  if (CheckTypeIsInstance<TypeId>(tensor_index->data_type(), kIntTypes)) {
    tensor_positions->emplace_back(tuple_index_new->size());
    tuple_index_new->emplace_back(tensor_index);
    tensor_indexes->emplace_back(tensor_index);
  } else if (tensor_index->data_type() == kNumberTypeBool) {
    std::vector<TensorPtr> nonzero_indices_tensors =
      GenerateNonZeroIndexTensorList(data_shape, tensor_index, check_align);
    if (nonzero_indices_tensors.empty()) {
      return false;
    }
    // Each nonzero output tensor occupies one consecutive tuple position.
    int64_t nonzero_indices_position = SizeToLong(tuple_index_new->size());
    (void)std::transform(nonzero_indices_tensors.begin(), nonzero_indices_tensors.end(),
                         std::back_inserter(*tensor_positions),
                         [&nonzero_indices_position](auto &) { return nonzero_indices_position++; });
    tuple_index_new->insert(tuple_index_new->end(), nonzero_indices_tensors.begin(), nonzero_indices_tensors.end());
    tensor_indexes->insert(tensor_indexes->end(), nonzero_indices_tensors.begin(), nonzero_indices_tensors.end());
  } else {
    MS_EXCEPTION(IndexError) << "The tensor element in tuple index must be int or bool type, but got "
                             << TypeIdToString(tensor_index->data_type(), false);
  }
  return true;
}
533
GetStrideInfoFromTuple(const ShapeVector & data_shape,const std::vector<TensorIndex> & tuple_index)534 std::tuple<std::vector<std::vector<int64_t>>, std::vector<int64_t>> TensorIndex::GetStrideInfoFromTuple(
535 const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index) {
536 const size_t data_dim = data_shape.size();
537 const size_t tuple_index_len = tuple_index.size();
538 const size_t stride_slice_info_size = std::min(tuple_index_len, data_dim);
539 std::vector<int64_t> begin_info(stride_slice_info_size);
540 std::vector<int64_t> end_info(stride_slice_info_size);
541 std::vector<int64_t> step_info(stride_slice_info_size);
542
543 size_t index_count = 0;
544 int64_t shrink_axis = 0;
545 int64_t ellipsis_count = 0;
546
547 for (size_t i = 0; i < stride_slice_info_size; i++) {
548 const TensorIndex &index = tuple_index[i];
549
550 int64_t dim_size = data_shape[i];
551 if (index.IsSlice()) {
552 Slice slice_info = Slice(index.slice(), dim_size);
553 begin_info[i] = slice_info.start();
554 end_info[i] = slice_info.stop();
555 step_info[i] = slice_info.step();
556 index_count += 1;
557 } else if (index.IsInteger()) {
558 const auto mask_bit = 1 << index_count;
559 begin_info[i] = index.integer();
560 end_info[i] = index.integer() + 1;
561 step_info[i] = 1;
562 shrink_axis += mask_bit;
563 index_count += 1;
564 } else if (index.IsEllipsis()) {
565 ellipsis_count = ellipsis_count + 1;
566 if (ellipsis_count > 1) {
567 MS_EXCEPTION(ValueError) << "An Tensor index can have only one ellipsis (...) ";
568 }
569 auto ellipsis_range_size = data_dim - tuple_index_len + 1;
570 for (size_t j = 0; j < ellipsis_range_size; j++) {
571 MS_EXCEPTION_IF_CHECK_FAIL(index_count + j < stride_slice_info_size && index_count + j < data_dim,
572 "Index out of data dims");
573 begin_info[index_count + j] = 0;
574 end_info[index_count + j] = data_shape[index_count + j];
575 step_info[index_count + j] = 1;
576 }
577 index_count += ellipsis_range_size;
578 }
579 }
580
581 int64_t begin_mask = 0;
582 int64_t end_mask = 0;
583
584 for (size_t i = 0; i < tuple_index_len; i++) {
585 if (tuple_index[i].IsSlice()) {
586 Slice slice_info = tuple_index[i].slice();
587 const auto mask_bit = 1 << i;
588 if (slice_info.start_init_by_none()) {
589 begin_mask += mask_bit;
590 }
591 if (slice_info.stop_init_by_none()) {
592 end_mask += mask_bit;
593 }
594 }
595 }
596 for (size_t i = tuple_index_len; i < data_dim; i++) {
597 const auto mask_bit = 1 << i;
598 begin_mask += mask_bit;
599 end_mask += mask_bit;
600 }
601
602 return std::make_tuple(std::vector<std::vector<int64_t>>({begin_info, end_info, step_info}),
603 std::vector<int64_t>({begin_mask, end_mask, shrink_axis}));
604 }
605
// Detect None/bool entries in a tuple index; each inserts a new axis of size 1.
// Returns {needs_expand, data shape with the extra 1-dims inserted, the tuple
// index with None replaced by a full slice and True replaced by tensor [0]}.
// A literal False in the index raises IndexError.
std::tuple<bool, ShapeVector, std::vector<TensorIndex>> TensorIndex::GetExpandDimsInfo(
  const ShapeVector &data_shape, const std::vector<TensorIndex> &index) {
  bool need_expand_dims = std::any_of(index.begin(), index.end(), [](auto &x) { return x.IsNone() || x.IsBoolean(); });
  if (!need_expand_dims) {
    return std::make_tuple(false, ShapeVector(), std::vector<TensorIndex>());
  }
  std::vector<TensorIndex> new_tuple_index;
  std::vector<int64_t> expand_dims_info;
  for (size_t i = 0; i < index.size(); i++) {
    if (index[i].IsNone()) {
      // None -> full slice; remember where the new axis goes.
      (void)new_tuple_index.emplace_back(tensor::Slice());
      (void)expand_dims_info.emplace_back(i);
    } else if (index[i].IsBoolean()) {
      if (!index[i].boolean()) {
        MS_EXCEPTION(IndexError) << "Bool element of tuple index must be 'True', but got 'False'.";
      }
      // True -> index tensor [0] selecting the single element of the new axis.
      (void)new_tuple_index.emplace_back(std::make_shared<Tensor>(std::vector<int64_t>({0})));
      (void)expand_dims_info.emplace_back(i);
    } else {
      (void)new_tuple_index.emplace_back(index[i]);
    }
  }
  // Insert a 1 into the data shape at each recorded position (clamped to rank).
  auto reshape_info = data_shape;
  for (auto dim : expand_dims_info) {
    dim = std::min(dim, SizeToLong(reshape_info.size()));
    (void)reshape_info.insert(reshape_info.begin() + dim, 1);
  }

  return std::make_tuple(need_expand_dims, reshape_info, new_tuple_index);
}
636
// Build the final advanced-index array: every entry of tuple_index_new is
// aligned with final_shape — genuine tensor indices via BroadCastTensor,
// slice-derived index tensors via SliceToArray — then all are stacked along
// the last axis into one numpy array of coordinates.
py::object TensorIndex::GenerateIndices(const std::vector<TensorPtr> &tuple_index_new,
                                        const std::vector<int64_t> &broadcast_shape,
                                        const std::vector<int64_t> &index_tensor_new_shape,
                                        const std::vector<int64_t> &final_shape,
                                        const std::vector<int64_t> &tensor_positions,
                                        const std::vector<int64_t> &slice_shapes, int64_t fancy_position) {
  py::tuple final_index_tensors(tuple_index_new.size());
  size_t slice_cnt = 0;
  for (size_t i = 0; i < tuple_index_new.size(); i++) {
    if (std::find(tensor_positions.begin(), tensor_positions.end(), i) != tensor_positions.end()) {
      // Position i holds a genuine tensor index.
      TensorIndex transform_tensor =
        BroadCastTensor(broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]);
      final_index_tensors[i] = transform_tensor.array();
    } else {
      // Position i holds a slice that was expanded to an index tensor;
      // slice_cnt tracks how many slices have been emitted so far.
      TensorIndex slice_index_tensor =
        SliceToArray(tuple_index_new[i], final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position);

      final_index_tensors[i] = slice_index_tensor.array();
      slice_cnt += 1;
    }
  }
  return TensorIndex::np_module_.attr("array")(TensorIndex::np_module_.attr("stack")(final_index_tensors, -1));
}
660
// Implements tensor getitem with a tuple index (e.g. data[1, [0, 2], 1:3]).
// Each int/sequence/tensor/slice element of the tuple is converted to an index
// tensor; the final GatherND index is either assembled on host with numpy
// (small data) or returned as raw pieces plus transfer ops so it can be
// assembled on device (large data, or empty result shapes).
py::object TensorIndex::TensorGetitemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tuple_index,
                                             std::vector<int64_t> *data_transfer_types,
                                             std::vector<py::object> *data_transfer_args) {
  size_t data_dims = data_shape.size();
  std::vector<TensorPtr> tensor_indexes;       // tensor-valued index pieces (for shape broadcasting)
  std::vector<TensorPtr> tuple_index_new;      // all index pieces, in tuple order
  std::vector<int64_t> slice_shapes;           // element count of every slice piece
  std::vector<int64_t> tensor_positions;       // positions in tuple_index_new that hold tensor indices
  size_t tuple_index_len = tuple_index.size();
  bool empty_mask_tensor = false;
  const size_t min_length = std::min(data_dims, tuple_index_len);
  for (size_t i = 0; i < min_length; i++) {
    int64_t dim_size = data_shape[i];
    const TensorIndex &index = tuple_index[i];

    if (index.IsInteger()) {
      int64_t int_index = index.integer();
      if (int_index >= dim_size || int_index < -dim_size) {
        MS_EXCEPTION(IndexError) << "Index " << int_index << " is out of bounds for dimension with size " << dim_size;
      }
      // Wrap negative indices into [0, dim_size) and treat as a 0-d tensor index.
      int_index = CheckRange(int_index, dim_size);
      TensorPtr tensor_index = std::make_shared<Tensor>(int_index);
      (void)tensor_positions.emplace_back(tuple_index_new.size());
      (void)tuple_index_new.emplace_back(tensor_index);
      (void)tensor_indexes.emplace_back(tensor_index);
    } else if (index.IsSequence()) {
      TensorIndex sequence_list = SequenceToTensor(index, data_shape[i]);
      TensorPtr tensor_index = sequence_list.tensor();
      (void)tensor_positions.emplace_back(tuple_index_new.size());
      (void)tuple_index_new.emplace_back(tensor_index);
      (void)tensor_indexes.emplace_back(tensor_index);
    } else if (index.IsTensor()) {
      const TensorPtr &tensor_index = index.tensor();
      if (!TensorGetitemByTupleParseTensorIndex(data_shape, tensor_index, &tuple_index_new, &tensor_indexes,
                                                &tensor_positions, false)) {
        // Parse failed (e.g. bool mask with no true element): substitute empty
        // int32 index tensors, one per mask dim, and flag the result as empty.
        TensorPtr new_tensor_index = std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({0}));
        for (int j = 0; j < tensor_index->DataDim(); j++) {
          (void)tensor_positions.emplace_back(tuple_index_new.size());
          (void)tuple_index_new.emplace_back(new_tensor_index);
          (void)tensor_indexes.emplace_back(new_tensor_index);
        }
        empty_mask_tensor = true;
      }
    } else if (index.IsSlice()) {
      Slice slice_info = Slice(index.slice(), dim_size);
      int64_t start = slice_info.start();
      int64_t stop = slice_info.stop();
      int64_t step = slice_info.step();

      // Materialize the slice as an explicit list of selected positions.
      std::vector<int64_t> slice_ele_list_index;
      for (int64_t j = start; j < stop; j += step) {
        (void)slice_ele_list_index.emplace_back(j);
      }
      (void)slice_shapes.emplace_back(SizeToLong(slice_ele_list_index.size()));
      (void)tuple_index_new.emplace_back(std::make_shared<Tensor>(slice_ele_list_index));
    }
  }
  tuple_index_len = tuple_index.size();
  std::vector<ShapeVector> tensor_indexes_shapes;
  (void)std::transform(
    tensor_indexes.begin(), tensor_indexes.end(), std::back_inserter(tensor_indexes_shapes), [](auto &tensor_index) {
      if (tensor_index == nullptr) {
        MS_EXCEPTION(IndexError) << "IndexError: The sequence element(tuple/list) in tuple index can't be empty.";
      }
      return tensor_index->shape();
    });
  // Combine tensor shapes and slice lengths into broadcast/final shapes and the
  // insertion point of the advanced-index dims.
  std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> index_info = GenerateIndexInfoFromTupleOfMixedTensors(
    tensor_positions, tensor_indexes_shapes, slice_shapes, TensorIndex(py::none()));
  constexpr size_t broadcast_shape_index = 0;
  constexpr size_t index_tensor_new_shape_index = 1;
  constexpr size_t final_shape_index = 2;
  constexpr size_t fancy_position_index = 3;
  ShapeVector broadcast_shape = std::get<broadcast_shape_index>(index_info);
  ShapeVector index_tensor_new_shape = std::get<index_tensor_new_shape_index>(index_info);
  ShapeVector final_shape = std::get<final_shape_index>(index_info);
  int64_t fancy_position = std::get<fancy_position_index>(index_info);
  if (empty_mask_tensor) {
    // An all-false mask selects nothing: emit an empty tensor of the final shape.
    (void)data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kEmptyTensor));
    (void)data_transfer_args->emplace_back(VectorToPyTuple(final_shape));
    return py::make_tuple(py::none(), VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
  }
  if (std::find(final_shape.begin(), final_shape.end(), 0) != final_shape.end() ||
      std::find(data_shape.begin(), data_shape.end(), 0) != data_shape.end()) {
    // Zero-sized dim anywhere: the gather result is empty; append the untouched
    // trailing data dims so the empty tensor still has the right rank.
    if (tuple_index_len < data_dims) {
      (void)final_shape.insert(final_shape.end(), data_shape.begin() + SizeToLong(tuple_index_len), data_shape.end());
    }
    data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kEmptyTensor));
    data_transfer_args->emplace_back(VectorToPyTuple(final_shape));
    return py::make_tuple(py::none(), VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
  }

  data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kGatherND));
  data_transfer_args->emplace_back(py::make_tuple(
    VectorToPyTuple(broadcast_shape), VectorToPyTuple(final_shape), VectorToPyTuple(index_tensor_new_shape),
    VectorToPyTuple(slice_shapes), VectorToPyTuple(tensor_positions), fancy_position));
  if (CheckLargeTensor(data_shape)) {
    // Large data: hand back the raw index pieces; the device assembles the index.
    return py::make_tuple(tuple_index_new, VectorToPyTuple(*data_transfer_types), VectorToPyTuple(*data_transfer_args));
  }
  py::array new_index = GenerateIndices(tuple_index_new, broadcast_shape, index_tensor_new_shape, final_shape,
                                        tensor_positions, slice_shapes, fancy_position);
  return py::make_tuple(TensorPy::MakeTensor(CastToInt(new_index)), VectorToPyTuple(*data_transfer_types),
                        VectorToPyTuple(*data_transfer_args));
}
764
765 // ***********************************************for set_item*******************************************
// Normalizes a python-list index: a list of nothing but ints/bools collapses to
// one index tensor, while mixed content is normalized element-wise and then
// handled as a tuple index.
TensorIndex TensorIndex::FormatList(const TensorIndex &tensor_index, int64_t length) {
  auto is_scalar_like = [](const auto &element) {
    return py::isinstance<py::int_>(element) || py::isinstance<py::bool_>(element);
  };
  if (std::all_of(tensor_index.list_.begin(), tensor_index.list_.end(), is_scalar_like)) {
    return SequenceToTensor<py::list>(tensor_index.list_, length);
  }
  return TensorIndex(DeepList(tensor_index.list_, length).cast<py::tuple>());
}
775
// Converts a single integer index on dim 0 into a full-coordinate index tensor
// usable by GatherNd: the first column holds the (range-checked) index, the
// remaining columns enumerate every coordinate of the trailing dimensions.
TensorPtr TensorIndex::IntToTensor(int64_t int_index, const ShapeVector &shape) {
  int64_t dim_size = shape[0];
  auto out_i = static_cast<int32_t>(CheckRange(int_index, dim_size));
  if (shape.size() == 1) {
    // 1-D data: the index tensor is simply [[i]].
    return std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({1, 1}), &out_i, int32_bytes_number);
  }

  ShapeVector index_shape(shape.begin() + 1, shape.end());
  int64_t grids_size = SizeToLong(shape.size()) - 1;
  py::tuple grids(grids_size);
  // One int32 arange per trailing dimension; np.ix_ turns them into an open
  // mesh that broadcasts against each other.
  for (size_t i = 1; i < shape.size(); i++) {
    grids[i - 1] =
      TensorIndex::np_module_.attr("arange")(0, py::int_(shape[i]), 1, TensorIndex::np_module_.attr("int32"));
  }
  py::object mesh = TensorIndex::np_module_.attr("ix_")(*grids);
  py::tuple index(SizeToLong(shape.size()));
  // Column 0: the scalar index broadcast over the trailing-dim shape.
  index[0] =
    TensorIndex::np_module_.attr("full")(py::cast(index_shape), py::int_(out_i), TensorIndex::np_module_.attr("int32"));
  py::tuple broadcast_mesh = TensorIndex::np_module_.attr("broadcast_arrays")(*mesh);
  for (size_t i = 1; i < shape.size(); i++) {
    index[i] = broadcast_mesh[i - 1];
  }
  // Stack along the last axis so each row is one complete coordinate tuple.
  py::object output_index = TensorIndex::np_module_.attr("stack")(index, -1);
  return TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index));
}
801
// Builds a GatherNd/ScatterNd index from a tuple made up solely of int tensors:
// every tensor is broadcast to a common shape and the results are stacked along
// a new last axis. For large data the raw tensors plus the broadcast shape are
// returned instead so the broadcast/stack runs on device.
py::object TensorIndex::GenerateIndicesFromTupleOfTensor(const ShapeVector &data_shape,
                                                         const std::vector<TensorIndex> &tuple_index,
                                                         ShapeVector *output_index_shape,
                                                         py::object *data_transfer_arg) {
  std::vector<ShapeVector> tensor_index_shape;
  std::vector<TensorPtr> tuple_index_vector;
  for (const auto &index : tuple_index) {
    TensorPtr index_tensor = index.tensor();
    MS_EXCEPTION_IF_NULL(index_tensor);
    (void)tuple_index_vector.emplace_back(index_tensor);
    // Only integer index tensors are valid here (bool masks are expanded earlier).
    if (!CheckTypeIsInstance<TypeId>(index_tensor->data_type(), kIntTypes)) {
      string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
      MS_EXCEPTION(IndexError) << "For '" << index_op_type << "', the index tensor data type '"
                               << index_tensor->data_type() << "' is not supported.";
    }
  }
  (void)std::transform(tuple_index_vector.begin(), tuple_index_vector.end(), std::back_inserter(tensor_index_shape),
                       [](const TensorPtr &x) { return x->shape(); });
  ShapeVector broadcast_shape = BroadCastShape(tensor_index_shape);

  // Keep the broadcast shape at least 2-D by prepending a leading 1.
  constexpr int64_t min_broadcast_shape_size = 2;
  if (SizeToLong(broadcast_shape.size()) < min_broadcast_shape_size) {
    (void)broadcast_shape.insert(broadcast_shape.begin(), 1);
  }

  // Output index shape = broadcast shape + one trailing axis of length
  // len(tuple_index) (the coordinate tuple per element).
  *output_index_shape = broadcast_shape;
  output_index_shape->emplace_back(tuple_index.size());
  if (CheckLargeTensor(data_shape)) {
    // Large data: defer broadcasting to device; return the raw tensors.
    *data_transfer_arg = py::make_tuple(VectorToPyTuple(broadcast_shape));
    return VectorToPyTuple(tuple_index_vector);
  }

  std::vector<py::array> broadcast_tensors;
  (void)std::transform(tuple_index.begin(), tuple_index.end(), std::back_inserter(broadcast_tensors),
                       [&broadcast_shape](auto &index) {
                         return TensorIndex::np_module_.attr("broadcast_to")(
                           CastToInt(TensorPy::SyncAsNumpy(*index.tensor())), broadcast_shape);
                       });
  py::array output_index = TensorIndex::np_module_.attr("stack")(py::cast(broadcast_tensors), -1);
  return py::cast(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index)));
}
843
RemNotExpandedDims(int64_t * idx_advanced,bool expand_true,int64_t tensor_index_ndim,int64_t rem_ndim,std::vector<bool> * not_expanded_dim)844 void TensorIndex::RemNotExpandedDims(int64_t *idx_advanced, bool expand_true, int64_t tensor_index_ndim,
845 int64_t rem_ndim, std::vector<bool> *not_expanded_dim) {
846 if (*idx_advanced != -1) {
847 std::vector<bool> tensor_dims(tensor_index_ndim, true);
848 if (expand_true) {
849 tensor_dims = {false};
850 }
851 *idx_advanced = std::min(*idx_advanced, SizeToLong(not_expanded_dim->size()));
852 not_expanded_dim->insert(not_expanded_dim->begin() + *idx_advanced, tensor_dims.begin(), tensor_dims.end());
853 }
854 std::vector<bool> rem_ndim_vector(rem_ndim, true);
855 not_expanded_dim->insert(not_expanded_dim->end(), rem_ndim_vector.begin(), rem_ndim_vector.end());
856 size_t count_leading_false = 0;
857 while (count_leading_false < not_expanded_dim->size() && !((*not_expanded_dim)[count_leading_false])) {
858 count_leading_false += 1;
859 }
860 *idx_advanced = std::max(static_cast<int64_t>(0), *idx_advanced - SizeToLong(count_leading_false));
861 }
862
// Normalizes one element of a tuple index against dimension `cur_dim`:
// sequences and ints become index tensors, negative entries of int tensors are
// wrapped into [0, dim), and other element kinds pass through unchanged.
// *need_format is set when the data is too large to normalize on host, so a
// kFormatIndexTensor transfer op must do it on device instead.
TensorIndex TensorIndex::FormatIndex(const TensorIndex &idx, const ShapeVector &data_shape, size_t cur_dim,
                                     bool *need_format) {
  if (!CheckTypeIsInstance<TensorIndexType>(idx.type(), {TensorIndexType::List, TensorIndexType::Tuple,
                                                         TensorIndexType::Integer, TensorIndexType::Tensor})) {
    return idx;
  }
  MS_EXCEPTION_IF_CHECK_FAIL(cur_dim < data_shape.size(), "Index" + std::to_string(cur_dim) + "out of data dims" +
                                                            std::to_string(data_shape.size()));
  int64_t dims_size = data_shape[cur_dim];
  if (idx.IsSequence()) {
    return SequenceToTensor(idx, dims_size);
  } else if (idx.IsInteger()) {
    return TensorIndex(std::make_shared<Tensor>(CheckRange(idx.integer(), dims_size)));
  }
  const TensorPtr &tensor_idx = idx.tensor();
  MS_EXCEPTION_IF_NULL(tensor_idx);
  if (CheckTypeIsInstance<TypeId>(tensor_idx->data_type(), kIntTypes)) {
    if (CheckLargeTensor(data_shape)) {
      // Too large for host-side numpy normalization; caller emits a transfer op.
      *need_format = true;
      return idx;
    }
    py::array new_idx = TensorPy::SyncAsNumpy(*tensor_idx);
    if (tensor_idx->DataDim() == 0) {
      // Scalar tensor index: wrap a negative value directly.
      auto new_int_idx = new_idx.cast<int64_t>();
      new_int_idx = new_int_idx < 0 ? new_int_idx + dims_size : new_int_idx;
      return TensorIndex(std::make_shared<Tensor>(new_int_idx));
    }
    // numpy op select is very slow for one dim array
    new_idx = TensorIndex::np_module_.attr("expand_dims")(new_idx, 0);
    // Wrap negative entries: where(idx < 0, idx + dims_size, idx).
    new_idx = TensorIndex::np_module_.attr("select")(TensorIndex::np_module_.attr("less")(new_idx, 0),
                                                     TensorIndex::np_module_.attr("add")(new_idx, py::int_(dims_size)),
                                                     new_idx);
    new_idx = TensorIndex::np_module_.attr("squeeze")(new_idx, 0);
    return TensorIndex(TensorPy::MakeTensor(CastToInt(new_idx)));
  } else if (tensor_idx->data_type() != kNumberTypeBool) {
    // Only int and bool index tensors are supported.
    string index_op_type = index_op_type_ == IndexOpType::GetItem ? "tensor getitem" : "tensor setitem";
    MS_EXCEPTION(IndexError) << "For '" << index_op_type << "', the index tensor data type '"
                             << TypeIdToString(tensor_idx->data_type(), false) << "' is not supported.";
  }
  return idx;
}
904
// Handles one tensor element of a tuple index inside RemoveExpandedDims.
// A bool mask tensor is expanded into its nonzero coordinate tensors (one per
// mask dim); an int tensor is appended as-is. Returns false when the bool mask
// has no true element, which lets the caller bypass the whole operation.
bool TensorIndex::RemoveExpandedDimsParseTensorIndex(const ShapeVector &data_shape, const TensorPtr &index_out,
                                                     std::vector<TensorIndex> *indices_out,
                                                     std::vector<ShapeVector> *shapes, bool *has_sequence,
                                                     size_t *cur_dim, bool check_align) {
  // Parse tensor_index
  MS_EXCEPTION_IF_NULL(index_out);
  if (index_out->data_type() == kNumberTypeBool) {
    // Bool mask: nonzero() yields one coordinate tensor per mask dimension.
    std::vector<TensorPtr> nonzero_indices_tensors = GenerateNonZeroIndexTensorList(data_shape, index_out, check_align);
    if (nonzero_indices_tensors.empty()) {
      // All-false mask: nothing is selected.
      return false;
    }
    std::vector<TensorIndex> true_index_tensors;
    (void)std::transform(nonzero_indices_tensors.begin(), nonzero_indices_tensors.end(),
                         std::back_inserter(true_index_tensors),
                         [](const TensorPtr &true_index) { return TensorIndex(true_index); });
    size_t true_index_nums = nonzero_indices_tensors.size();
    indices_out->insert(indices_out->end(), true_index_tensors.begin(), true_index_tensors.end());
    MS_EXCEPTION_IF_NULL(nonzero_indices_tensors[0]);
    // All coordinate tensors from nonzero() share the first tensor's shape.
    std::vector<ShapeVector> true_index_shapes(true_index_nums, {nonzero_indices_tensors[0]->shape()});
    shapes->insert(shapes->end(), true_index_shapes.begin(), true_index_shapes.end());
    // The mask consumes one data dim per mask dim.
    *cur_dim += true_index_nums;
  } else {
    if (index_out->DataDim() > 0) {
      // Non-scalar int tensor counts as a sequence-style index.
      *has_sequence = true;
    }
    indices_out->emplace_back(index_out);
    shapes->emplace_back(index_out->shape());
    *cur_dim += 1;
  }
  return true;
}
936
RemoveExpandedDims(const std::vector<TensorIndex> & indices,const ShapeVector & data_shape,const ShapeVector & value_shape,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args,int64_t * idx_advanced,bool * by_pass,std::vector<size_t> * format_index,std::vector<int64_t> * format_dim)937 std::pair<std::vector<TensorIndex>, ShapeVector> TensorIndex::RemoveExpandedDims(
938 const std::vector<TensorIndex> &indices, const ShapeVector &data_shape, const ShapeVector &value_shape,
939 std::vector<int64_t> *value_transfer_types, std::vector<py::object> *value_transfer_args, int64_t *idx_advanced,
940 bool *by_pass, std::vector<size_t> *format_index, std::vector<int64_t> *format_dim) {
941 // Removes expanded dimensions in tuple_index and value.
942 size_t cur_dim = 0;
943 bool has_true = false;
944 bool has_false = false;
945 bool has_sequence = false;
946 int64_t idx_tensor = -1;
947 std::vector<bool> not_expanded_dim;
948 std::vector<TensorIndex> indices_out;
949 std::vector<ShapeVector> shapes;
950
951 for (size_t i = 0; i < indices.size(); i++) {
952 const TensorIndex &v = indices[i];
953 bool need_format = false;
954 TensorIndex index_out = TensorIndex::FormatIndex(v, data_shape, cur_dim, &need_format);
955 if (need_format) {
956 (void)format_index->emplace_back(cur_dim);
957 (void)format_dim->emplace_back(data_shape[cur_dim]);
958 }
959 if (index_out.IsNone()) {
960 (void)not_expanded_dim.emplace_back(false);
961 } else if (index_out.IsSlice()) {
962 (void)indices_out.emplace_back(index_out);
963 (void)not_expanded_dim.emplace_back(true);
964 Slice slice_info = Slice(v.slice(), data_shape[cur_dim]);
965
966 int64_t start = slice_info.start();
967 int64_t stop = slice_info.stop();
968 int64_t step = slice_info.step();
969 has_false = ((start - stop) * step > 0) || has_false;
970 cur_dim += 1;
971 } else if (index_out.IsBoolean() || index_out.IsTensor()) {
972 if (*idx_advanced == -1) {
973 *idx_advanced = SizeToLong(not_expanded_dim.size());
974 } else if (static_cast<int64_t>(i) - idx_tensor > 1) {
975 *idx_advanced = 0;
976 }
977 idx_tensor = static_cast<int64_t>(i);
978 if (index_out.IsTensor()) {
979 const TensorPtr &index_out_tensor = index_out.tensor();
980 if (!RemoveExpandedDimsParseTensorIndex(data_shape, index_out_tensor, &indices_out, &shapes, &has_sequence,
981 &cur_dim, false)) {
982 *by_pass = true;
983 *idx_advanced = 0;
984 return {std::vector<TensorIndex>(), ShapeVector()};
985 }
986 } else {
987 bool bool_index_out = index_out.boolean();
988 has_true = bool_index_out || has_true;
989 has_false = !bool_index_out || has_false;
990 }
991 } else {
992 MS_EXCEPTION(IndexError) << "Invalid index type, index: " << TensorIndex::py_index_handle_;
993 }
994 }
995
996 ShapeVector broadcast_shape = BroadCastShape(shapes);
997 if (has_false) {
998 if (std::accumulate(broadcast_shape.begin(), broadcast_shape.end(), 1, std::multiplies<>()) != 1) {
999 MS_EXCEPTION(IndexError) << "Unable to broadcast indices " << broadcast_shape;
1000 }
1001 *by_pass = true;
1002 return std::make_pair(std::vector<TensorIndex>(), ShapeVector());
1003 }
1004
1005 bool expand_true = has_true && !(has_false || has_sequence);
1006 int64_t tensor_index_ndim = SizeToLong(broadcast_shape.size());
1007 int64_t rem_ndim = SizeToLong(data_shape.size()) - SizeToLong(cur_dim);
1008 RemNotExpandedDims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, ¬_expanded_dim);
1009 if (indices_out.empty()) {
1010 indices_out = {TensorIndex(py::bool_(true))};
1011 }
1012 value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kReshape));
1013 ShapeVector reshape_info = FilterExpandedDims(value_shape, not_expanded_dim);
1014 value_transfer_args->emplace_back(py::cast(reshape_info));
1015 *by_pass = false;
1016 return std::make_pair(indices_out, reshape_info);
1017 }
1018
// Builds a GatherNd/ScatterNd index from a mixed tuple (ints, sequences, int
// tensors, slices). Slices are materialized as explicit position tensors.
// Sets *by_pass when a slice is empty, meaning the whole operation is a no-op.
// For large data the raw pieces plus assembly args are returned instead of a
// host-built index.
py::object TensorIndex::GenerateIndicesFromTuple(const ShapeVector &data_shape,
                                                 const std::vector<TensorIndex> &tuple_index, int64_t py_fancy_position,
                                                 bool *by_pass, ShapeVector *output_index_shape,
                                                 py::object *data_transfer_arg) {
  std::vector<TensorPtr> tensor_indexes;       // tensor-valued pieces (for broadcasting)
  std::vector<TensorPtr> tuple_index_new;      // all pieces, in tuple order
  std::vector<int64_t> slice_shapes;           // element count of every slice piece
  std::vector<int64_t> tensor_positions;       // tuple positions holding tensor indices
  std::vector<ShapeVector> tensor_indexes_shapes;
  const size_t min_length = std::min(data_shape.size(), tuple_index.size());
  for (size_t i = 0; i < min_length; i++) {
    const TensorIndex &index = tuple_index[i];
    int64_t dim_size = data_shape[i];

    if (index.IsInteger()) {
      int64_t int_index = index.integer();
      if (int_index >= dim_size || int_index < -dim_size) {
        MS_EXCEPTION(IndexError) << "Index " << int_index << " is out of bounds for dimension with size " << dim_size;
      }
      // Wrap negatives and treat the int as a 0-d tensor index.
      int_index = CheckRange(int_index, dim_size);
      TensorPtr tensor_index = std::make_shared<Tensor>(int_index);
      MS_EXCEPTION_IF_NULL(tensor_index);
      (void)tuple_index_new.emplace_back(tensor_index);
      (void)tensor_indexes.emplace_back(tensor_index);
      (void)tensor_positions.emplace_back(i);
      (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
    } else if (index.IsSequence()) {
      TensorIndex sequence_list = SequenceToTensor(index, data_shape[i]);
      TensorPtr tensor_index = sequence_list.tensor();
      (void)tuple_index_new.emplace_back(tensor_index);
      (void)tensor_indexes.emplace_back(tensor_index);
      (void)tensor_positions.emplace_back(i);
      MS_EXCEPTION_IF_NULL(tensor_index);
      (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
    } else if (index.IsTensor()) {
      TensorPtr tensor_index = index.tensor();
      if (!CheckTypeIsInstance<TypeId>(tensor_index->data_type(), kIntTypes)) {
        MS_EXCEPTION(TypeError) << "The tensor element in tuple index must be int type, but got "
                                << tensor_index->data_type();
      }
      (void)tuple_index_new.emplace_back(tensor_index);
      (void)tensor_indexes.emplace_back(tensor_index);
      (void)tensor_positions.emplace_back(i);
      (void)tensor_indexes_shapes.emplace_back(tensor_index->shape());
    } else if (index.IsSlice()) {
      Slice slice_info = Slice(index.slice(), dim_size);
      int64_t start = slice_info.start();
      int64_t stop = slice_info.stop();
      int64_t step = slice_info.step();
      if ((start - stop) * step >= 0) {
        // Empty slice: the whole operation selects nothing.
        *by_pass = true;
        return py::none();
      }
      std::vector<int64_t> slice_ele_list_index = SliceToVector(start, stop, step);
      (void)slice_shapes.emplace_back(SizeToLong(slice_ele_list_index.size()));
      (void)tuple_index_new.emplace_back(std::make_shared<Tensor>(slice_ele_list_index));
    }
  }

  // Combine tensor shapes and slice lengths into broadcast/final shapes and the
  // insertion point of the advanced-index dims.
  std::tuple<ShapeVector, ShapeVector, ShapeVector, int64_t> index_info = GenerateIndexInfoFromTupleOfMixedTensors(
    tensor_positions, tensor_indexes_shapes, slice_shapes, TensorIndex(py_fancy_position));
  constexpr size_t k_broadcast_shape_index = 0;
  constexpr size_t index_tensor_new_shape_index = 1;
  constexpr size_t final_shape_index = 2;
  constexpr size_t fancy_position_index = 3;
  ShapeVector broadcast_shape = std::get<k_broadcast_shape_index>(index_info);
  ShapeVector index_tensor_new_shape = std::get<index_tensor_new_shape_index>(index_info);
  ShapeVector final_shape = std::get<final_shape_index>(index_info);
  // Output index shape = final shape + one trailing coordinate axis.
  *output_index_shape = final_shape;
  output_index_shape->emplace_back(tuple_index_new.size());
  int64_t fancy_position = std::get<fancy_position_index>(index_info);
  if (CheckLargeTensor(data_shape)) {
    // Large data: defer index assembly to device.
    *data_transfer_arg = py::make_tuple(VectorToPyTuple(broadcast_shape), VectorToPyTuple(final_shape),
                                        VectorToPyTuple(index_tensor_new_shape), VectorToPyTuple(slice_shapes),
                                        VectorToPyTuple(tensor_positions), fancy_position);
    return VectorToPyTuple(tuple_index_new);
  }
  py::array output_index = GenerateIndices(tuple_index_new, broadcast_shape, index_tensor_new_shape, final_shape,
                                           tensor_positions, slice_shapes, fancy_position);
  return py::cast(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(output_index)));
}
1100
ReSetitemByTensor(const std::vector<TensorIndex> & new_tuple_index,const std::vector<int64_t> & value_transfer_types,const std::vector<py::object> & value_transfer_args)1101 py::object TensorIndex::ReSetitemByTensor(const std::vector<TensorIndex> &new_tuple_index,
1102 const std::vector<int64_t> &value_transfer_types,
1103 const std::vector<py::object> &value_transfer_args) {
1104 py::object output_py_index;
1105 if (new_tuple_index[0].IsSlice()) {
1106 Slice slice_info = new_tuple_index[0].slice();
1107 output_py_index = py::slice(slice_info.start(), slice_info.stop(), slice_info.step());
1108 } else if (new_tuple_index[0].IsTensor()) {
1109 output_py_index = py::cast(new_tuple_index[0].tensor());
1110 } else {
1111 output_py_index = py::cast(new_tuple_index[0].boolean());
1112 }
1113 return py::make_tuple(
1114 output_py_index, VectorToPyTuple<int64_t>(value_transfer_types), VectorToPyTuple<py::object>(value_transfer_args),
1115 py::make_tuple(static_cast<int>(ValueTransferType::kReSetItemByIndex)), py::make_tuple(py::none()));
1116 }
1117
// Implements tensor setitem with a tuple index. Fast path: an [int, slice, ...]
// prefix on non-Ascend backends becomes a CopySlice with broadcast value.
// Otherwise expanded dims are removed, the remaining tuple is converted to a
// scatter index, and the value is cast/broadcast to the update shape for
// TensorScatterUpdate.
py::object TensorIndex::SetitemByTupleWithTensor(const ShapeVector &data_shape, const std::vector<TensorIndex> &indices,
                                                 const ShapeVector &value_shape,
                                                 std::vector<int64_t> *value_transfer_types,
                                                 std::vector<py::object> *value_transfer_args) {
  std::vector<TensorIndex> new_indices = TransformEllipsisToSlice(data_shape, indices);
  ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
  if (UseCopySlice(new_indices, SizeToLong(data_shape.size())) && !TensorIndex::is_ascend_) {
    // CopySlice fast path: new_indices starts with [int, slice].
    Slice slice_info = Slice(new_indices[1].slice(), data_shape[1]);
    int64_t dim1_start = slice_info.start();
    int64_t dim1_stop = slice_info.stop();
    if (dim1_stop - dim1_start <= 0) {
      // Empty slice: nothing to write.
      tensor_update_type = ValueTransferType::kByPass;
      return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
                            VectorToPyTuple<py::object>(*value_transfer_args),
                            py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
    }
    if (data_shape.empty()) {
      MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
    }
    // Wrap a negative dim-0 index.
    int64_t dim0_start =
      new_indices[0].integer() >= 0 ? new_indices[0].integer() : new_indices[0].integer() + data_shape[0];
    py::tuple start = py::make_tuple(dim0_start, dim1_start);
    py::tuple stop = py::make_tuple(dim0_start + 1, dim1_stop);
    py::tuple step = py::make_tuple(1, 1);

    // Value must broadcast to [len(slice), *data_shape[2:]].
    ShapeVector new_value_shape = {dim1_stop - dim1_start};
    constexpr int64_t start_position_of_data_shape = 2;
    (void)new_value_shape.insert(new_value_shape.end(), data_shape.begin() + start_position_of_data_shape,
                                 data_shape.end());
    value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
    value_transfer_args->emplace_back(VectorToPyTuple(new_value_shape));
    value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
    value_transfer_args->emplace_back(py::none());
    tensor_update_type = ValueTransferType::kCopySlice;
    return py::make_tuple(
      py::none(), VectorToPyTuple<int64_t>(*value_transfer_types), VectorToPyTuple<py::object>(*value_transfer_args),
      py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::make_tuple(start, stop, step)));
  }
  int64_t idx_advanced = -1;
  bool by_pass = false;
  std::vector<size_t> format_index;
  std::vector<int64_t> format_dim;
  std::pair<std::vector<TensorIndex>, ShapeVector> tuple_index_info =
    RemoveExpandedDims(new_indices, data_shape, value_shape, value_transfer_types, value_transfer_args, &idx_advanced,
                       &by_pass, &format_index, &format_dim);
  if (by_pass) {
    // Index selects nothing: skip the assignment entirely.
    tensor_update_type = ValueTransferType::kByPass;
    return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
                          VectorToPyTuple<py::object>(*value_transfer_args),
                          py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
  }

  MS_LOG(DEBUG) << "After remove expand dims: " << tuple_index_info.first;

  std::vector<TensorIndex> new_tuple_index = tuple_index_info.first;
  ShapeVector new_value_shape = tuple_index_info.second;

  if (new_tuple_index.size() == 1) {
    // Only one index left: re-dispatch through the simple setitem path.
    return ReSetitemByTensor(new_tuple_index, *value_transfer_types, *value_transfer_args);
  }
  py::object output_index;
  ShapeVector output_index_shape;
  py::object data_transfer_args = py::none();
  if (std::all_of(new_tuple_index.begin(), new_tuple_index.end(), [](const TensorIndex &x) { return x.IsTensor(); })) {
    output_index =
      GenerateIndicesFromTupleOfTensor(data_shape, new_tuple_index, &output_index_shape, &data_transfer_args);
  } else {
    by_pass = false;
    output_index = GenerateIndicesFromTuple(data_shape, new_tuple_index, idx_advanced, &by_pass, &output_index_shape,
                                            &data_transfer_args);
    if (by_pass) {
      tensor_update_type = ValueTransferType::kByPass;
      return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
                            VectorToPyTuple<py::object>(*value_transfer_args),
                            py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
    }
  }

  value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
  value_transfer_args->emplace_back(py::make_tuple());
  // updates_shape = index shape without the coordinate axis, plus the data dims
  // the index does not cover.
  ShapeVector updates_shape(output_index_shape.begin(), output_index_shape.end() - 1);

  if (output_index_shape.back() < SizeToLong(data_shape.size())) {
    (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + output_index_shape.back(), data_shape.end());
  }

  if (updates_shape != new_value_shape) {
    value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
    value_transfer_args->emplace_back(VectorToPyTuple(updates_shape));
  }
  std::vector<int> tensor_update_types{static_cast<int>(tensor_update_type)};
  std::vector<py::object> tensor_update_args{data_transfer_args};
  if (!format_index.empty()) {
    // Index tensors deferred by FormatIndex are normalized on device first.
    (void)tensor_update_types.insert(tensor_update_types.begin(),
                                     static_cast<int>(ValueTransferType::kFormatIndexTensor));
    (void)tensor_update_args.insert(tensor_update_args.begin(), py::make_tuple(VectorToPyTuple<size_t>(format_index),
                                                                               VectorToPyTuple<int64_t>(format_dim)));
  }
  if (py::isinstance<py::tuple>(output_index)) {
    return py::make_tuple(py::cast<py::list>(output_index), VectorToPyTuple<int64_t>(*value_transfer_types),
                          VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
                          VectorToPyTuple<py::object>(tensor_update_args));
  }
  return py::make_tuple(py::cast<TensorPtr>(output_index), VectorToPyTuple<int64_t>(*value_transfer_types),
                        VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
                        VectorToPyTuple<py::object>(tensor_update_args));
}
1225
GetStubTensorValue(const py::handle & obj)1226 ValuePtr GetStubTensorValue(const py::handle &obj) {
1227 auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
1228 ValuePtr stub = py_stub.cast<stub::StubNodePtr>();
1229 if (stub == nullptr) {
1230 auto tensor_ptr = py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
1231 MS_EXCEPTION_IF_NULL(tensor_ptr);
1232 stub = tensor_ptr;
1233 }
1234 return stub;
1235 }
1236
SqueezeRDataValue(const TensorPtr & tensor,const py::handle & py_value,const ValuePtr & rdata_value)1237 ValuePtr SqueezeRDataValue(const TensorPtr &tensor, const py::handle &py_value, const ValuePtr &rdata_value) {
1238 auto rdata_shape = tensor->shape();
1239 if (rdata_shape.size() >= 1 && (rdata_shape.at(0) > 1 || rdata_shape.size() > 1)) {
1240 MS_EXCEPTION(ValueError)
1241 << "For SetItem, the shape of right value must be () or (1, ) when shape of left value is 0, but got"
1242 << rdata_shape;
1243 } else if (rdata_shape.size() == 1 && rdata_shape.at(0) == 1) {
1244 auto new_value = py::cast<py::list>(py_value);
1245 auto first_value = new_value[0];
1246 ValuePtr result =
1247 IsStubTensor(first_value) ? GetStubTensorValue(first_value) : first_value.cast<tensor::TensorPtr>();
1248 return result;
1249 }
1250 return rdata_value;
1251 }
1252
// Builds and runs the view-copy op chain (Cast -> BroadcastTo -> CopyWithSlice)
// that writes `py_value` into the selected view of the data tensor.
// Returns py::none() when py_value is not a tensor/int/float/bool the fast
// path understands, signalling the caller to fall back to the slow path.
static inline py::object SetitemCopyView(std::vector<pynative::SliceOpInfoPtr> *slice_op_infos,
                                         const ValuePtr data_value, const std::vector<int64_t> &new_data_shape,
                                         const TypePtr &data_type, const py::handle &py_value) {
  // Cast the value (input 1) to the data tensor's dtype.
  auto cast_op_info = std::make_shared<pynative::SliceOpInfo>();
  cast_op_info->slice_op_name = prim::kPrimCast->name();
  (void)cast_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(data_type->type_id()));
  cast_op_info->data_indexs = {1};
  (void)slice_op_infos->emplace_back(cast_op_info);

  // Broadcast the value to the view's shape.
  auto broadcastto_op_info = std::make_shared<pynative::SliceOpInfo>();
  broadcastto_op_info->slice_op_name = prim::kPrimBroadcastTo->name();
  (void)broadcastto_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(new_data_shape));
  broadcastto_op_info->data_indexs = {1};
  (void)slice_op_infos->emplace_back(broadcastto_op_info);

  // Copy the broadcast value into the view of the data (inputs 0 and 1).
  auto copy_op_info = std::make_shared<pynative::SliceOpInfo>();
  copy_op_info->slice_op_name = kCopyWithSliceOpName;
  copy_op_info->data_indexs = {0, 1};
  (void)slice_op_infos->emplace_back(copy_op_info);
  // Convert py_value into a ValuePtr the executor can consume.
  ValuePtr rdata_value;
  if (IsStubTensor(py_value)) {
    rdata_value = GetStubTensorValue(py_value);
    if (new_data_shape.size() == 0) {
      // 0-d target: the right value may need squeezing from (1,) to scalar.
      auto tensor = ConvertStubTensor(py_value);
      rdata_value = SqueezeRDataValue(tensor, py_value, rdata_value);
    }
  } else if (py::isinstance<Tensor>(py_value)) {
    auto tensor = py_value.cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    rdata_value = tensor;
    if (new_data_shape.size() == 0) {
      rdata_value = SqueezeRDataValue(tensor, py_value, rdata_value);
    }
  } else if (py::isinstance<py::int_>(py_value)) {
    rdata_value = MakeValue(py::cast<int64_t>(py_value));
  } else if (py::isinstance<py::float_>(py_value)) {
    rdata_value = MakeValue(py::cast<float>(py_value));
  } else if (py::isinstance<py::bool_>(py_value)) {
    rdata_value = MakeValue(py::cast<bool>(py_value));
  } else {
    // Unsupported value kind: caller falls back to the generic setitem path.
    return py::none();
  }
  return pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value, rdata_value}, *slice_op_infos);
}
1297
// Builds the transfer plan for `data[slice] = value`. Three strategies, in order:
//   1. View path (step >= 0 and data_value available): run StridedSlice plus
//      Cast/BroadcastTo/CopyWithSlice eagerly and just return the result.
//   2. CopySlice path (contiguous step == 1 on a non-Ascend backend).
//   3. Fallback: materialize the slice as integer indices for TensorScatterUpdate.
py::object TensorIndex::SetitemBySliceWithTensor(const ShapeVector &data_shape, const TensorIndex &slice_index,
                                                 std::vector<int64_t> *value_transfer_types,
                                                 std::vector<py::object> *value_transfer_args,
                                                 const ValuePtr &data_value, const TypePtr &data_type) {
  ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
  Slice slice_info = Slice(slice_index.slice(), data_shape[0]);
  int64_t start = slice_info.start();
  int64_t stop = slice_info.stop();
  int64_t step = slice_info.step();
  if (step >= 0 && data_value != nullptr) {
    std::vector<int64_t> data_transfer_types;
    std::vector<py::object> data_transfer_args;
    std::vector<int64_t> begin_info(data_shape.size(), 0);
    std::vector<int64_t> end_info(data_shape);
    std::vector<int64_t> step_info(data_shape.size(), 1);
    std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
    if (start >= stop) {
      // Empty slice: nothing to write, return the data untouched.
      (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
      return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
                            py::tuple());
    }
    // Only prepend a StridedSlice op when the slice is not the whole of axis 0.
    if (slice_info.start() != 0 || slice_info.step() != 1 || slice_info.stop() != end_info[0]) {
      begin_info[0] = slice_info.start();
      end_info[0] = slice_info.stop();
      step_info[0] = slice_info.step();
      auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
      slice_op_info->slice_op_name = prim::kPrimStridedSlice->name();
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(begin_info));
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(end_info));
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(step_info));
      (void)slice_op_info->data_indexs.emplace_back(0);
      (void)slice_op_infos.emplace_back(slice_op_info);
    }
    auto new_data_shape = data_shape;
    if (step != 0) {
      // Sliced length of axis 0: ceil((stop - start) / step), clamped at 0.
      auto new_shape_zero = (stop - start) / step;
      new_data_shape[0] = (new_shape_zero < 0 ? 0 : (stop + step - 1 - start) / step);
    }
    auto slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
    if (slice_output != py::none()) {
      data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
      data_transfer_args.emplace_back(slice_output);
      return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
                            VectorToPyTuple(data_transfer_args));
    }
    // SetitemCopyView refused the value type: fall back to StridedSlice + CopyView.
    (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kStrideSlice));
    (void)data_transfer_args.emplace_back(py::make_tuple(
      py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()), py::make_tuple(slice_info.step())));
    (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCopyView));
    (void)data_transfer_args.emplace_back(py::none());
    return py::make_tuple(py::str("view"), VectorToPyTuple<int64_t>(*value_transfer_types),
                          VectorToPyTuple<py::object>(*value_transfer_args), VectorToPyTuple(data_transfer_types),
                          VectorToPyTuple(data_transfer_args));
  }
  if (slice_index.slice().step() == 1 && !TensorIndex::is_ascend_) {
    // Contiguous slice on a non-Ascend backend: broadcast the value, use CopySlice.
    if (data_shape.empty()) {
      MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
    }
    int64_t dim0_size = stop - start;
    if (dim0_size <= 0) {
      // Empty region: pass the data through unchanged.
      tensor_update_type = ValueTransferType::kByPass;
      return py::make_tuple(py::none(), VectorToPyTuple<int64_t>(*value_transfer_types),
                            VectorToPyTuple<py::object>(*value_transfer_args),
                            py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
    }
    // The value is broadcast to {dim0_size, data_shape[1:]...} and cast to data dtype.
    ShapeVector value_shape = {dim0_size};
    (void)value_shape.insert(value_shape.end(), data_shape.begin() + 1, data_shape.end());
    value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
    value_transfer_args->emplace_back(VectorToPyTuple(value_shape));
    value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
    value_transfer_args->emplace_back(py::none());
    tensor_update_type = ValueTransferType::kCopySlice;
    return py::make_tuple(
      py::none(), VectorToPyTuple<int64_t>(*value_transfer_types), VectorToPyTuple<py::object>(*value_transfer_args),
      py::make_tuple(static_cast<int>(tensor_update_type)),
      py::make_tuple(py::make_tuple(py::make_tuple(start), py::make_tuple(stop), py::make_tuple(step))));
  }
  // General path: materialize the slice as integer indices for TensorScatterUpdate.
  TensorIndex indices = SliceToArray(slice_index, data_shape);
  if (indices.IsBoolean()) {
    // NOTE(review): a boolean result from SliceToArray appears to mark a degenerate
    // slice; the update becomes a pass-through — confirm against SliceToArray.
    tensor_update_type = ValueTransferType::kByPass;
    return py::make_tuple(indices.boolean(), VectorToPyTuple<int64_t>(*value_transfer_types),
                          VectorToPyTuple<py::object>(*value_transfer_args),
                          py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
  }
  value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
  TensorPtr indices_tensor = TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(indices.array()));
  MS_EXCEPTION_IF_NULL(indices_tensor);
  // The value broadcasts to the index shape minus the trailing coordinate axis.
  ShapeVector broad_cast_shape(indices_tensor->shape().begin(), indices_tensor->shape().end() - 1);
  value_transfer_args->emplace_back(VectorToPyTuple(broad_cast_shape));
  value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
  value_transfer_args->emplace_back(py::none());
  return py::make_tuple(indices_tensor, VectorToPyTuple<int64_t>(*value_transfer_types),
                        VectorToPyTuple<py::object>(*value_transfer_args),
                        py::make_tuple(static_cast<int>(tensor_update_type)), py::make_tuple(py::none()));
}
1393
SetItemByTensorByBool(const ShapeVector & data_shape,const TensorPtr & index,int64_t data_dims,std::vector<int64_t> * value_transfer_types,std::vector<py::object> * value_transfer_args,ValueTransferType * tensor_update_type)1394 py::array TensorIndex::SetItemByTensorByBool(const ShapeVector &data_shape, const TensorPtr &index, int64_t data_dims,
1395 std::vector<int64_t> *value_transfer_types,
1396 std::vector<py::object> *value_transfer_args,
1397 ValueTransferType *tensor_update_type) {
1398 ShapeVector index_shape = GeneratePaddingShape(index->shape(), data_dims);
1399 py::array np_index = TensorPy::SyncAsNumpy(*index);
1400 py::array output_np_index = TensorIndex::np_module_.attr("broadcast_to")(
1401 TensorIndex::np_module_.attr("reshape")(np_index, VectorToPyTuple(index_shape)), VectorToPyTuple(data_shape));
1402 value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kCast));
1403 value_transfer_args->emplace_back(py::none());
1404 value_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1405 value_transfer_args->emplace_back(VectorToPyTuple(data_shape));
1406 *tensor_update_type = ValueTransferType::kSelect;
1407 return output_np_index;
1408 }
1409
1410 // ***********************************************get get_item info*******************************************
GetItemByTensor(const ShapeVector & data_shape,const TensorPtr & index)1411 py::object TensorIndex::GetItemByTensor(const ShapeVector &data_shape, const TensorPtr &index) {
1412 MS_EXCEPTION_IF_NULL(index);
1413 MS_LOG(DEBUG) << "In branch get item by tensor, data_shape: " << data_shape
1414 << " tensor_indexes: " << index->ToString();
1415 constexpr int min_data_dim = 1;
1416 constexpr int max_data_dim = 7;
1417 const int64_t data_dim = SizeToLong(data_shape.size());
1418 JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1419 py::object output = py::none();
1420 if (CheckTypeIsInstance<TypeId>(index->data_type(), kIntTypes)) {
1421 output =
1422 py::make_tuple(index, py::make_tuple(static_cast<int>(ValueTransferType::kGather)), py::make_tuple(py::none()));
1423 } else if (index->data_type() == kNumberTypeBool) {
1424 py::tuple nonzero_indices = GenerateNonZeroIndex(data_shape, index, true);
1425 MS_EXCEPTION_IF_CHECK_FAIL(!nonzero_indices.empty(), "Output size of nonzero should not be empty");
1426 int64_t nonzero_indices_nums = SizeToLong(len(py::array(nonzero_indices[0])));
1427 if (nonzero_indices_nums == 0) {
1428 ShapeVector empty_tensor_shape(data_shape.begin() + index->DataDim(), data_shape.end());
1429 (void)empty_tensor_shape.insert(empty_tensor_shape.begin(), 0);
1430
1431 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kEmptyTensor)),
1432 py::make_tuple(VectorToPyTuple(empty_tensor_shape)));
1433 }
1434 output = py::make_tuple(index, py::make_tuple(static_cast<int>(ValueTransferType::kGetitemByBoolTensor)),
1435 py::make_tuple(py::none()));
1436 } else {
1437 MS_EXCEPTION(IndexError) << "The tensor index must be int or bool type, but got " << TensorIndex::py_index_handle_;
1438 }
1439 return output;
1440 }
1441
GetItemByList(const ShapeVector & data_shape,const TensorIndex & tensor_index)1442 py::object TensorIndex::GetItemByList(const ShapeVector &data_shape, const TensorIndex &tensor_index) {
1443 MS_LOG(DEBUG) << "In branch get item by List, data_shape: " << data_shape << " tensor_index: " << tensor_index;
1444 constexpr int min_data_dim = 1;
1445 constexpr int max_data_dim = 8;
1446 int64_t data_dim = SizeToLong(data_shape.size());
1447 JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1448 bool use_gather = std::all_of(tensor_index.list().begin(), tensor_index.list().end(),
1449 [](auto &x) { return py::isinstance<py::int_>(x) || py::isinstance<py::bool_>(x); });
1450 if (use_gather) {
1451 if (data_shape.empty()) {
1452 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1453 }
1454 TensorIndex tuple_index = SequenceToTensor(tensor_index, data_shape[0]);
1455 if (tuple_index.IsBoolean() && !tuple_index.boolean()) {
1456 MS_EXCEPTION(IndexError) << "When tensor is indexed by list, the list can't be empty.";
1457 }
1458 return py::make_tuple(tuple_index.tensor(), py::make_tuple(static_cast<int>(ValueTransferType::kGather)),
1459 py::make_tuple(py::none()));
1460 }
1461 return GetItemByTuple(data_shape, tensor_index.ExpandToVector());
1462 }
1463
JudgeTupleIndexDim(int64_t data_dim,const std::vector<TensorIndex> & new_tuple_indexes)1464 static void JudgeTupleIndexDim(int64_t data_dim, const std::vector<TensorIndex> &new_tuple_indexes) {
1465 int64_t index_dims = 0;
1466 for (const TensorIndex &index : new_tuple_indexes) {
1467 if (index.IsTensor() && index.tensor() != nullptr && index.tensor()->data_type() == kNumberTypeBool) {
1468 index_dims += index.tensor()->DataDim();
1469 } else {
1470 index_dims += 1;
1471 }
1472 }
1473 if (index_dims > data_dim) {
1474 MS_EXCEPTION(IndexError) << "The dim of index cannot be greater than indexed data, but got dim of index:"
1475 << index_dims << ", dim of data:" << data_dim;
1476 }
1477 }
1478
GetSpecifiedDimensions(const py::tuple & new_tuple_index,size_t data_dims)1479 size_t GetSpecifiedDimensions(const py::tuple &new_tuple_index, size_t data_dims) {
1480 size_t specified_dimensions = std::count_if(new_tuple_index.begin(), new_tuple_index.end(), [](auto const &obj) {
1481 return (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False);
1482 });
1483 constexpr size_t max_data_dim = 8;
1484 if (data_dims > max_data_dim) {
1485 MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [0, " << max_data_dim << "], but got '"
1486 << data_dims << "'.";
1487 }
1488 if (specified_dimensions > data_dims) {
1489 MS_EXCEPTION(IndexError) << "too many indices for tensor of dimension" << data_dims;
1490 }
1491 return specified_dimensions;
1492 }
1493
1494 namespace {
CheckDataDim(const ShapeVector & data_shape)1495 void CheckDataDim(const ShapeVector &data_shape) {
1496 constexpr size_t max_data_dim = 8;
1497 if (data_shape.size() > max_data_dim) {
1498 MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [1, " << max_data_dim << "], but got '"
1499 << data_shape.size() << "'.";
1500 }
1501 }
1502
CheckNumberOfEllipsis(const size_t counter)1503 void CheckNumberOfEllipsis(const size_t counter) {
1504 if (counter > 0) {
1505 MS_EXCEPTION(IndexError) << "An index can only have a single ellipsis('...')";
1506 }
1507 }
1508 } // namespace
1509
// Tries to lower a tuple index made only of ints / slices / None / ellipsis into a
// chain of view ops (SelectView / StridedSlice / ExpandDims) executed eagerly.
// When `data_type` is non-null this is the setitem path (the chain ends with a
// CopyWithSlice via SetitemCopyView); otherwise it is the getitem path.
// Returns true when a result (or a pass-through / raise marker) was produced here;
// false means the caller must fall back to the generic tuple handling.
bool TensorIndex::GetItemByTupleWithView(const ValuePtr &data_value, const ShapeVector &data_shape,
                                         const py::object &py_index, std::vector<int64_t> *data_transfer_types,
                                         std::vector<py::object> *data_transfer_args, const TypePtr &data_type) {
  if (data_value == nullptr) {
    return false;
  }
  MS_LOG(DEBUG) << "In branch get item by tuple with view, data_shape: " << data_shape
                << " tensor_indexes: " << py_index;
  size_t data_dims = data_shape.size();
  auto new_tuple_index = py_index.cast<py::tuple>();
  size_t specified_dimensions = GetSpecifiedDimensions(new_tuple_index, data_dims);
  bool empty_strided_slice_result = false;
  // `new_data_shape` tracks the result shape as each tuple element is consumed;
  // `dim` is the axis the next element applies to.
  auto new_data_shape = data_shape;
  size_t dim = 0;
  std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
  size_t ellipsis_count = 0;
  for (auto const &obj : new_tuple_index) {
    if (py::isinstance<py::int_>(obj) && !py::isinstance<py::bool_>(obj)) {
      // Integer index: SelectView picks one entry and drops this axis.
      auto index = py::cast<int64_t>(obj);
      if (index >= new_data_shape[dim] || index < -new_data_shape[dim]) {
        // Raise exception in python, because python iterator need raise IndexError to stop for loop.
        data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kRaiseIndexError));
        data_transfer_args->emplace_back(py::make_tuple(index, new_data_shape[dim]));
        return true;
      }
      int64_t transformed_number = CheckRange(index, new_data_shape[dim]);
      auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
      slice_op_info->slice_op_name = prim::kPrimSelectView->name();
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
      (void)slice_op_info->data_indexs.emplace_back(0);
      (void)slice_op_infos.emplace_back(slice_op_info);
      (void)new_data_shape.erase(new_data_shape.begin() + dim);
    } else if (py::isinstance<py::slice>(obj)) {
      auto slice_info = Slice(TensorIndex(obj).slice(), new_data_shape[dim]);
      std::vector<int64_t> begin_info(new_data_shape.size(), 0);
      std::vector<int64_t> end_info(new_data_shape);
      std::vector<int64_t> step_info(new_data_shape.size(), 1);
      if (slice_info.step() < 0) {
        // Negative step is not expressible as a view here: abandon the view path.
        data_transfer_types->clear();
        data_transfer_args->clear();
        return false;
      }
      if (slice_info.start() == 0 && slice_info.step() == 1 && slice_info.stop() == end_info[dim]) {
        // Full-axis slice: nothing to do for this axis.
        dim++;
        continue;
      }
      empty_strided_slice_result = (slice_info.start() >= slice_info.stop());
      begin_info[dim] = slice_info.start();
      end_info[dim] = slice_info.stop();
      step_info[dim] = slice_info.step();
      auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
      slice_op_info->slice_op_name = prim::kPrimStridedSlice->name();
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(begin_info));
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(end_info));
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(step_info));
      (void)slice_op_info->data_indexs.emplace_back(0);
      (void)slice_op_infos.emplace_back(slice_op_info);
      // Sliced axis length: ceil((stop - start) / step).
      new_data_shape[dim] = (slice_info.stop() + slice_info.step() - 1 - slice_info.start()) / slice_info.step();
      dim++;
    } else if (py::isinstance<py::ellipsis>(obj)) {
      // Ellipsis skips over the axes no other element consumes.
      CheckNumberOfEllipsis(ellipsis_count);
      dim += data_shape.size() - specified_dimensions;
      ellipsis_count += 1;
    } else if (py::isinstance<py::none>(obj)) {
      // None inserts a new axis of size 1 at the current position.
      auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
      slice_op_info->slice_op_name = prim::kPrimExpandDims->name();
      (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
      (void)slice_op_info->data_indexs.emplace_back(0);
      (void)slice_op_infos.emplace_back(slice_op_info);
      new_data_shape.insert(new_data_shape.begin() + dim, 1);
      dim++;
    } else {
      // Tensor/list/other element: not lowerable to view ops, fall back.
      data_transfer_types->clear();
      data_transfer_args->clear();
      return false;
    }
  }
  CheckDataDim(new_data_shape);
  py::object slice_output;
  if (data_type != nullptr) {
    // Setitem path: append the copy chain and run it.
    if (empty_strided_slice_result) {
      // Writing into an empty region is a no-op.
      data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kByPass));
      data_transfer_args->emplace_back(py::none());
      return true;
    }
    slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
    if (slice_output == py::none()) {
      return false;
    }
  } else {
    // Getitem path: an empty op list means the index left the data unchanged.
    if (slice_op_infos.empty()) {
      data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kByPass));
      data_transfer_args->emplace_back(py::none());
      return true;
    }
    slice_output = pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value}, slice_op_infos);
  }
  data_transfer_types->emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
  data_transfer_args->emplace_back(slice_output);
  return true;
}
1612
// Builds the getitem transfer plan for a tuple index. Ellipsis is first expanded
// into explicit slices and `None` entries become a reshape; a tuple made only of
// ellipsis/int/slice is lowered to one StridedSliceWithMask, otherwise the generic
// advanced-indexing lowering (TensorGetitemByTuple) takes over.
py::object TensorIndex::GetItemByTuple(const ShapeVector &data_shape, const std::vector<TensorIndex> &tensor_indexes) {
  MS_LOG(DEBUG) << "In branch get item by tuple, data_shape: " << data_shape << " tensor_indexes: " << tensor_indexes;
  std::vector<int64_t> data_transfer_types;
  std::vector<py::object> data_transfer_args;
  ShapeVector new_data_shape = data_shape;
  // Empty tuple: return the data unchanged.
  if (tensor_indexes.empty()) {
    return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)),
                          py::make_tuple(py::none()));
  }
  std::vector<TensorIndex> new_tuple_indexes = TransformEllipsisToSlice(new_data_shape, tensor_indexes);
  std::tuple expand_dim_info = GetExpandDimsInfo(new_data_shape, new_tuple_indexes);
  constexpr size_t expand_dim_info_index = 0;
  constexpr size_t new_data_shape_index = 1;
  constexpr size_t new_tuple_indexes_index = 2;
  bool need_expand_dim = std::get<expand_dim_info_index>(expand_dim_info);
  if (need_expand_dim) {
    // `None` entries insert new axes: reshape the data before indexing.
    (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kReshape));
    new_data_shape = std::get<new_data_shape_index>(expand_dim_info);
    (void)data_transfer_args.emplace_back(VectorToPyTuple(new_data_shape));
    new_tuple_indexes = std::get<new_tuple_indexes_index>(expand_dim_info);  // NOLINT
  }
  constexpr int min_data_dim = 1;
  constexpr int max_data_dim = 8;
  int64_t data_dim = SizeToLong(new_data_shape.size());
  JudgeDataDim(data_dim, min_data_dim, max_data_dim);
  JudgeTupleIndexDim(data_dim, new_tuple_indexes);
  // A tuple of only ellipsis/int/slice needs no advanced indexing.
  bool normal_tuple = std::all_of(new_tuple_indexes.begin(), new_tuple_indexes.end(), [](auto &index_e) {
    return index_e.IsEllipsis() || index_e.IsInteger() || index_e.IsSlice();
  });
  if (normal_tuple) {
    std::tuple stride_slice_info = GetStrideInfoFromTuple(new_data_shape, new_tuple_indexes);
    (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kStrideSliceWithMask));
    std::vector<std::vector<int64_t>> stride_info = std::get<0>(stride_slice_info);
    std::vector<py::tuple> py_stride_info;
    (void)std::transform(stride_info.begin(), stride_info.end(), std::back_inserter(py_stride_info),
                         [](auto &stride_info_i) { return VectorToPyTuple(stride_info_i); });
    std::vector<int64_t> mask_info = std::get<1>(stride_slice_info);
    (void)data_transfer_args.emplace_back(py::make_tuple(VectorToPyTuple(py_stride_info), VectorToPyTuple(mask_info)));
    return py::make_tuple(py::none(), VectorToPyTuple(data_transfer_types), VectorToPyTuple(data_transfer_args));
  }
  return TensorGetitemByTuple(new_data_shape, new_tuple_indexes, &data_transfer_types, &data_transfer_args);
}
1655
GetItemByBool(const ValuePtr & data_value,const ShapeVector & data_shape,bool index)1656 py::object TensorIndex::GetItemByBool(const ValuePtr &data_value, const ShapeVector &data_shape, bool index) {
1657 MS_LOG(INFO) << "(View) In branch get item by bool, data_shape: " << data_shape << " tensor_indexes: " << index;
1658 constexpr int min_data_dim = 0;
1659 constexpr int max_data_dim = 7;
1660 int64_t data_dim = SizeToLong(data_shape.size());
1661 JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1662 if (!index) {
1663 MS_EXCEPTION(IndexError) << "When tensor is indexed by a bool object, the value only support 'True'.";
1664 }
1665 auto transfer_type = (data_value == nullptr ? ValueTransferType::kExpandDims : ValueTransferType::kUnsqueeze);
1666 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(transfer_type)), py::make_tuple(py::int_(0)));
1667 }
1668
GetItemByNumber(const ShapeVector & data_shape,int64_t index)1669 py::object TensorIndex::GetItemByNumber(const ShapeVector &data_shape, int64_t index) {
1670 MS_LOG(DEBUG) << "In branch get item by number, data_shape: " << data_shape << " tensor_indexes: " << index;
1671 if (data_shape.empty()) {
1672 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1673 }
1674 constexpr int min_data_dim = 1;
1675 constexpr int max_data_dim = 8;
1676 int64_t data_dim = SizeToLong(data_shape.size());
1677 JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1678 if (index >= data_shape[0] || index < -data_shape[0]) {
1679 // Raise exception in python, because python iterator need raise IndexError to stop for loop.
1680 return py::make_tuple(py::make_tuple(py::none()),
1681 py::make_tuple(static_cast<int>(ValueTransferType::kRaiseIndexError)),
1682 py::make_tuple(py::make_tuple(index, data_shape[0])));
1683 }
1684 int64_t transformed_number = CheckRange(index, data_shape[0]);
1685 if (!TensorIndex::is_ascend_) {
1686 return py::make_tuple(std::make_shared<Tensor>(transformed_number),
1687 py::make_tuple(static_cast<int>(ValueTransferType::kGather)), py::make_tuple(py::none()));
1688 }
1689 std::vector<int64_t> begin_strides = {transformed_number};
1690 std::vector<int64_t> end_strides = {transformed_number + 1};
1691 std::vector<int64_t> step_strides = {1};
1692 for (size_t i = 1; i < data_shape.size(); i++) {
1693 (void)begin_strides.emplace_back(0);
1694 (void)end_strides.emplace_back(data_shape[i]);
1695 (void)step_strides.emplace_back(1);
1696 }
1697 int64_t shrink_axis_mask = 1;
1698 int64_t begin_mask = 0;
1699 int64_t end_mask = 0;
1700 constexpr size_t begin_mask_begin_bit = 2;
1701 constexpr size_t begin_mask_end_bit = 8;
1702 for (size_t i = begin_mask_begin_bit; i < begin_mask_end_bit; i++) {
1703 const auto mask_bit = 1 << i;
1704 begin_mask += mask_bit;
1705 end_mask += mask_bit;
1706 }
1707
1708 py::tuple stride_info =
1709 py::make_tuple(VectorToPyTuple(begin_strides), VectorToPyTuple(end_strides), VectorToPyTuple(step_strides));
1710 py::tuple mask_info = py::make_tuple(begin_mask, end_mask, shrink_axis_mask);
1711 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSliceWithMask)),
1712 py::make_tuple(py::make_tuple(stride_info, mask_info)));
1713 }
1714
GetItemByNumberWithView(const ValuePtr & data_value,const ShapeVector & data_shape,int64_t index)1715 py::object TensorIndex::GetItemByNumberWithView(const ValuePtr &data_value, const ShapeVector &data_shape,
1716 int64_t index) {
1717 MS_LOG(INFO) << "(View) In branch get item by number, data_shape: " << data_shape << " tensor_indexes: " << index;
1718 if (data_shape.empty()) {
1719 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1720 }
1721 constexpr int min_data_dim = 1;
1722 constexpr int max_data_dim = 8;
1723 int64_t data_dim = SizeToLong(data_shape.size());
1724 JudgeDataDim(data_dim, min_data_dim, max_data_dim);
1725 if (index >= data_shape[0] || index < -data_shape[0]) {
1726 // Raise exception in python, because python iterator need raise IndexError to stop for loop.
1727 return py::make_tuple(py::make_tuple(py::none()),
1728 py::make_tuple(static_cast<int>(ValueTransferType::kRaiseIndexError)),
1729 py::make_tuple(py::make_tuple(index, data_shape[0])));
1730 }
1731 int64_t transformed_number = CheckRange(index, data_shape[0]);
1732 // return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kSelectView)),
1733 // py::make_tuple(py::make_tuple(py::int_(transformed_number), py::int_(0))));
1734 int64_t dim = 0;
1735 auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1736
1737 slice_op_info->slice_op_name = prim::kPrimSelectView->name();
1738 (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
1739 (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(dim));
1740 (void)slice_op_info->data_indexs.emplace_back(0);
1741
1742 auto slice_output = pynative::PyNativeExecutor::GetInstance()->RunSliceOpStub({data_value}, {slice_op_info});
1743 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kJustReturn)),
1744 py::make_tuple(slice_output));
1745 }
1746
GetItemBySlice(const ValuePtr & data_value,const ShapeVector & data_shape,const TensorIndex & py_index)1747 py::object TensorIndex::GetItemBySlice(const ValuePtr &data_value, const ShapeVector &data_shape,
1748 const TensorIndex &py_index) {
1749 MS_LOG(INFO) << "(View) In branch get item by slice, data_shape: " << data_shape << " tensor_indexes: " << py_index;
1750 constexpr int min_data_dim = 1;
1751 constexpr int max_data_dim = 8;
1752 size_t data_dim = data_shape.size();
1753 JudgeDataDim(SizeToLong(data_dim), min_data_dim, max_data_dim);
1754 if (data_shape.empty()) {
1755 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1756 }
1757 Slice slice_info = Slice(py_index.slice(), data_shape[0]);
1758 if (slice_info.step() >= 0 && data_value != nullptr) {
1759 std::vector<int64_t> begin_info(data_dim, 0);
1760 std::vector<int64_t> end_info(data_shape);
1761 std::vector<int64_t> step_info(data_dim, 1);
1762 begin_info[0] = slice_info.start();
1763 end_info[0] = slice_info.stop();
1764 step_info[0] = slice_info.step();
1765 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSlice)),
1766 py::make_tuple(py::make_tuple(VectorToPyTuple(begin_info), VectorToPyTuple(end_info),
1767 VectorToPyTuple(step_info))));
1768 }
1769 int64_t begin_mask = slice_info.start_init_by_none() ? 1 : 0;
1770 int64_t end_mask = slice_info.stop_init_by_none() ? 1 : 0;
1771 for (size_t i = 1; i < data_dim; i++) {
1772 const auto mask_bit = 1 << i;
1773 begin_mask += mask_bit;
1774 end_mask += mask_bit;
1775 }
1776 if (begin_mask != 0 || end_mask != 0) {
1777 py::tuple stride_info = py::make_tuple(py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()),
1778 py::make_tuple(slice_info.step()));
1779 py::tuple mask_info = py::make_tuple(begin_mask, end_mask, 0);
1780 return py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSliceWithMask)),
1781 py::make_tuple(py::make_tuple(stride_info, mask_info)));
1782 }
1783 return py::make_tuple(
1784 py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kStrideSlice)),
1785 py::make_tuple(py::make_tuple(py::make_tuple(slice_info.start()), py::make_tuple(slice_info.stop()),
1786 py::make_tuple(slice_info.step()))));
1787 }
1788
// Fast path for "simple" indexes (bool / int-with-view / slice / None). Returns
// py::none() when the index is not simple, so the caller falls through to the
// generic handling.
py::object TensorIndex::GetItemIndexSimpleIndex(const py::object &py_index, const ValuePtr &data_value,
                                                const ShapeVector &data_shape) {
  if (py::isinstance<py::bool_>(py_index)) {
    return TensorIndex::GetItemByBool(data_value, data_shape, TensorIndex(py_index).boolean());
  }
  if (data_value != nullptr && py::isinstance<py::int_>(py_index)) {
    // The eager view path for integer indexes requires a data value.
    return TensorIndex::GetItemByNumberWithView(data_value, data_shape, TensorIndex(py_index).integer());
  }
  // NOTE(review): when the isinstance check fails, the second operand calls
  // .slice() on a TensorIndex built from a non-slice object — looks like this was
  // meant to guard slices only; confirm intended semantics before changing.
  if (py::isinstance<py::slice>(py_index) || TensorIndex(py_index).slice().step() == -1) {
    return TensorIndex::GetItemBySlice(data_value, data_shape, TensorIndex(py_index));
  }
  if (py::isinstance<py::none>(py_index)) {
    // `None` behaves like indexing with True: a new axis is inserted at position 0.
    return TensorIndex::GetItemByBool(data_value, data_shape, 1);
  }
  return py::none();
}
1805
GetStubAbsTypeId(const AbstractBasePtr & abs)1806 TypeId GetStubAbsTypeId(const AbstractBasePtr &abs) {
1807 MS_EXCEPTION_IF_NULL(abs);
1808
1809 if (abs->isa<abstract::AbstractTensor>()) {
1810 auto tensor_abs = abs->cast<abstract::AbstractTensorPtr>();
1811 MS_EXCEPTION_IF_NULL(tensor_abs);
1812 MS_EXCEPTION_IF_NULL(tensor_abs->element());
1813 MS_EXCEPTION_IF_NULL(tensor_abs->element()->BuildType());
1814 return tensor_abs->element()->BuildType()->type_id();
1815 } else {
1816 MS_EXCEPTION_IF_NULL(abs->BuildType());
1817 return abs->BuildType()->type_id();
1818 }
1819 }
1820
EnableView(bool is_setitem=false)1821 bool EnableView(bool is_setitem = false) {
1822 if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->is_high_order_top_cell()) {
1823 // 1. pack node will slice failed with view.
1824 // 2. SelectView and CopyWithSlice has no kernel, can not enable view in high order cell.
1825 return false;
1826 }
1827
1828 // For setitem, the grad of CopyWithSlice is erroneous. If we are in setitem and requires grad, disable view.
1829 if (is_setitem && pynative::PyNativeExecutor::GetInstance()->grad_executor()->RequiresGrad()) return false;
1830
1831 return true;
1832 }
1833
// Entry point for tensor getitem preprocessing: resolves the data shape (and the data value
// when view ops are enabled), tries the simple-index and tuple-view fast paths, then
// dispatches on the index type. Returns the transfer-instruction tuple consumed by the
// Python-side getitem implementation.
py::object TensorIndex::GetItemIndexInfo(const py::object &py_data, const py::object &py_index,
                                         const py::bool_ &is_ascend) {
  ShapeVector data_shape;
  ValuePtr data_value;
  if (IsStubTensor(py_data)) {
    auto value = GetStubTensorValue(py_data);
    MS_EXCEPTION_IF_NULL(value);
    auto abs = value->ToAbstract();
    MS_EXCEPTION_IF_NULL(abs);
    // Fix: guard the dyn_cast result before dereferencing; a non-Shape BuildShape() result
    // would otherwise be a null dereference.
    auto shape_abs = dyn_cast<abstract::Shape>(abs->BuildShape());
    MS_EXCEPTION_IF_NULL(shape_abs);
    data_shape = shape_abs->shape();

    if (EnableView()) {
      data_value = value;
    }
  } else if (py::isinstance<Tensor>(py_data)) {
    auto tensor = py_data.cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    if (EnableView()) {
      data_value = tensor;
    }
    data_shape = tensor->shape();
  } else {
    MS_EXCEPTION(TypeError) << "First input of Tensor index must be tensor but got " << py_data;
  }

  const auto &simple_index_output = GetItemIndexSimpleIndex(py_index, data_value, data_shape);
  // Fix: use the identity check (Python `is None`) instead of `!= py::none()`, which invokes
  // Python rich comparison and can misbehave or raise for objects overloading __ne__.
  if (!simple_index_output.is_none()) {
    return simple_index_output;
  }

  std::vector<int64_t> data_transfer_types;
  std::vector<py::object> data_transfer_args;
  if (py::isinstance<py::tuple>(py_index) &&
      GetItemByTupleWithView(data_value, data_shape, py_index, &data_transfer_types, &data_transfer_args, nullptr)) {
    MS_LOG(INFO) << "(View) In branch get item by tuple with view, data_shape: " << data_shape
                 << " tensor_indexes: " << py_index;
    return py::make_tuple(py::none(), VectorToPyTuple(data_transfer_types), VectorToPyTuple(data_transfer_args));
  }
  MS_LOG(INFO) << "(Tensor) Get item datashape is: " << data_shape << ", index is: " << py_index;
  py::object new_py_index = IsStubTensor(py_index) ? py::cast(ConvertStubTensor(py_index)) : py_index;
  // Record the current getitem request in the class-level static state used by the helpers.
  TensorIndex::py_index_handle_ = new_py_index;
  TensorIndex::is_ascend_ = is_ascend;
  TensorIndex::np_module_ = py::module::import("numpy");
  TensorIndex::index_op_type_ = IndexOpType::GetItem;
  TensorIndex index(new_py_index);
  CheckGetItemIndex(index.type());
  py::object output = py::none();
  switch (index.type()) {
    case TensorIndexType::Tensor: {
      output = GetItemByTensor(data_shape, index.tensor());
      break;
    }
    case TensorIndexType::List: {
      output = GetItemByList(data_shape, index);
      break;
    }
    case TensorIndexType::Tuple: {
      output = GetItemByTuple(data_shape, index.ExpandToVector());
      break;
    }
    case TensorIndexType::Boolean: {
      output = GetItemByBool(data_value, data_shape, index.boolean());
      break;
    }
    case TensorIndexType::Ellipsis: {
      // Ellipsis selects the whole tensor: pass the data through unchanged.
      output = py::make_tuple(py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)),
                              py::make_tuple(py::none()));
      break;
    }
    case TensorIndexType::Integer: {
      output = GetItemByNumber(data_shape, index.integer());
      break;
    }
    default: {
      MS_EXCEPTION(TypeError)
        << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
           "tuple as index, but got "
        << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
    }
  }
  return output;
}
1916
1917 // ***********************************************get set_item info*******************************************
SetItemByNumber(const ShapeVector & data_shape,const TypePtr & data_type,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type)1918 py::object TensorIndex::SetItemByNumber(const ShapeVector &data_shape, const TypePtr &data_type, bool is_parameter,
1919 const TensorIndex &tensor_index, const TensorIndexType &py_value_type) {
1920 // If tensor is small, we use method in IntToTensor for faster
1921 MS_LOG(DEBUG) << "In branch Set item by number, data_shape: " << data_shape << " tensor_indexes: " << tensor_index
1922 << "value: " << TensorIndex::py_value_handle_;
1923
1924 std::tuple<int64_t, py::object, ShapeVector> value_transfer =
1925 GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, false);
1926 std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
1927 std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
1928 if (data_shape.empty()) {
1929 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1930 }
1931 int64_t dim_size = data_shape[0];
1932 int64_t index = tensor_index.integer();
1933 if (index < -dim_size || index >= dim_size) {
1934 MS_EXCEPTION(IndexError) << "Index " << index << " is out of bounds for axis 0 with size " << dim_size;
1935 }
1936 TensorPtr new_index = std::make_shared<Tensor>();
1937 if (!CheckLargeTensor(data_shape)) {
1938 new_index = IntToTensor(index, data_shape);
1939 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1940 MS_EXCEPTION_IF_NULL(new_index);
1941 ShapeVector value_shape(new_index->shape().begin(), new_index->shape().end() - 1);
1942 value_transfer_args.push_back(VectorToPyTuple<int64_t>(value_shape));
1943 } else {
1944 auto out_i = static_cast<int32_t>(CheckRange(index, dim_size));
1945 new_index = std::make_shared<Tensor>(kNumberTypeInt32, ShapeVector({1, 1}), &out_i, int32_bytes_number);
1946 ShapeVector updates_shape = {1};
1947 (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
1948 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
1949 (void)value_transfer_args.emplace_back(VectorToPyTuple(updates_shape));
1950 }
1951 ValueTransferType data_transfer_type =
1952 is_parameter ? ValueTransferType::kScatterNdUpdate : ValueTransferType::kTensorScatterUpdate;
1953 return py::make_tuple(new_index, VectorToPyTuple<int64_t>(value_transfer_types),
1954 VectorToPyTuple<py::object>(value_transfer_args),
1955 py::make_tuple(static_cast<int>(data_transfer_type)), py::make_tuple(py::none()));
1956 }
1957
SetItemByNumberWithView(const ShapeVector & data_shape,const TypePtr & data_type,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type,const ValuePtr & data_value)1958 py::object TensorIndex::SetItemByNumberWithView(const ShapeVector &data_shape, const TypePtr &data_type,
1959 bool is_parameter, const TensorIndex &tensor_index,
1960 const TensorIndexType &py_value_type, const ValuePtr &data_value) {
1961 // If tensor is small, we use method in IntToTensor for faster
1962 MS_LOG(INFO) << "(View) In branch set item by number, data_shape: " << data_shape
1963 << " tensor_indexes: " << tensor_index << "value: " << TensorIndex::py_value_handle_;
1964
1965 std::tuple<int64_t, py::object, ShapeVector> value_transfer =
1966 GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, true);
1967 std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
1968 std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
1969 if (data_shape.empty()) {
1970 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
1971 }
1972 int64_t dim_size = data_shape[0];
1973 int64_t index = tensor_index.integer();
1974 if (index < -dim_size || index >= dim_size) {
1975 MS_EXCEPTION(IndexError) << "Index " << index << " is out of bounds for axis 0 with size " << dim_size;
1976 }
1977 ShapeVector updates_shape = {1};
1978 (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
1979 std::vector<int64_t> data_transfer_types;
1980 std::vector<py::object> data_transfer_args;
1981 int64_t transformed_number = CheckRange(index, data_shape.at(0));
1982
1983 std::vector<pynative::SliceOpInfoPtr> slice_op_infos;
1984 std::vector<int64_t> new_data_shape(data_shape.begin() + 1, data_shape.end());
1985 auto slice_op_info = std::make_shared<pynative::SliceOpInfo>();
1986 slice_op_info->slice_op_name = prim::kPrimSelectView->name();
1987 (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(transformed_number));
1988 (void)slice_op_info->slice_index_inputs.emplace_back(std::make_shared<pynative::FastValue>(0));
1989 (void)slice_op_info->data_indexs.emplace_back(0);
1990 (void)slice_op_infos.emplace_back(slice_op_info);
1991 auto slice_output = SetitemCopyView(&slice_op_infos, data_value, new_data_shape, data_type, py_value_handle_);
1992 if (slice_output != py::none()) {
1993 data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kJustReturn));
1994 data_transfer_args.emplace_back(slice_output);
1995 return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
1996 VectorToPyTuple(data_transfer_args));
1997 }
1998
1999 (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kSelectView));
2000 (void)data_transfer_args.emplace_back(py::make_tuple(py::int_(transformed_number), py::int_(0)));
2001 (void)data_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCopyView));
2002 (void)data_transfer_args.emplace_back(py::none());
2003 return py::make_tuple(py::str("view"), VectorToPyTuple<int64_t>(value_transfer_types),
2004 VectorToPyTuple<py::object>(value_transfer_args), VectorToPyTuple(data_transfer_types),
2005 VectorToPyTuple(data_transfer_args));
2006 }
2007
SetItemByTensor(const ShapeVector & data_shape,bool is_parameter,const TensorIndex & tensor_index,const TensorIndexType & py_value_type)2008 py::object TensorIndex::SetItemByTensor(const ShapeVector &data_shape, bool is_parameter,
2009 const TensorIndex &tensor_index, const TensorIndexType &py_value_type) {
2010 MS_LOG(DEBUG) << "In branch Set item by tensor, data_shape: " << data_shape << " tensor_indexes: " << tensor_index
2011 << "value: " << TensorIndex::py_value_handle_;
2012 std::vector<int64_t> value_transfer_types;
2013 std::vector<py::object> value_transfer_args;
2014 const TensorPtr &index = tensor_index.tensor();
2015 int64_t data_dims = SizeToLong(data_shape.size());
2016 MS_EXCEPTION_IF_NULL(index);
2017 bool format_index_tensor = false;
2018 ValueTransferType tensor_update_type = ValueTransferType::kTensorScatterUpdate;
2019 py::array np_index;
2020 if (CheckTypeIsInstance(py_value_type, {TensorIndexType::Float, TensorIndexType::Integer, TensorIndexType::Boolean,
2021 TensorIndexType::Tensor})) {
2022 if (!CheckTypeIsInstance<TypeId>(index->data_type(), {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
2023 kNumberTypeInt64, kNumberTypeBool})) {
2024 MS_EXCEPTION(IndexError) << "For tensor set item, the index tensor data type" << index->data_type()
2025 << " is not supported.";
2026 }
2027 if (index->data_type() == kNumberTypeBool) {
2028 if (CheckScalarValue(TensorIndex::py_value_handle_)) {
2029 np_index = SetItemByTensorByBool(data_shape, index, data_dims, &value_transfer_types, &value_transfer_args,
2030 &tensor_update_type);
2031 } else {
2032 return py::make_tuple(index, py::make_tuple(), py::make_tuple(),
2033 py::make_tuple(static_cast<int>(ValueTransferType::kSetitemByBoolTensor)),
2034 py::make_tuple(py::none()));
2035 }
2036 } else {
2037 ShapeVector index_shape = index->shape();
2038 np_index = TensorPy::SyncAsNumpy(*index);
2039 if (index_shape.empty()) {
2040 np_index = TensorIndex::np_module_.attr("expand_dims")(np_index, -1);
2041 (void)index_shape.emplace_back(1);
2042 }
2043 ShapeVector updates_shape = index_shape;
2044 (void)updates_shape.insert(updates_shape.end(), data_shape.begin() + 1, data_shape.end());
2045 if (py_value_type != TensorIndexType::Tensor) {
2046 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kNumberToTensor));
2047 } else {
2048 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kCast));
2049 }
2050 (void)value_transfer_args.emplace_back(py::none());
2051 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kBroadCast));
2052 (void)value_transfer_args.emplace_back(VectorToPyTuple(updates_shape));
2053 if (data_shape.empty()) {
2054 MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
2055 }
2056 int64_t index_shape_dim = std::accumulate(index_shape.begin(), index_shape.end(), 1, std::multiplies<>());
2057 if (index_shape_dim <= 1) {
2058 int64_t first_val = data_shape[0];
2059 np_index = TensorIndex::np_module_.attr("select")(
2060 TensorIndex::np_module_.attr("less")(np_index, 0),
2061 TensorIndex::np_module_.attr("add")(np_index, py::int_(first_val)), np_index);
2062 } else {
2063 format_index_tensor = true;
2064 }
2065 np_index = TensorIndex::np_module_.attr("expand_dims")(np_index, -1);
2066 (void)index_shape.emplace_back(1);
2067 constexpr int64_t min_index_shape_size = 2;
2068 if (index_shape.size() < min_index_shape_size) {
2069 auto np_expand_dims_method = TensorIndex::np_module_.attr("expand_dims");
2070 np_index = np_expand_dims_method(np_index, 0);
2071 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kExpandDims));
2072 (void)value_transfer_args.emplace_back(py::int_(0));
2073 }
2074 tensor_update_type = is_parameter ? ValueTransferType::kScatterNdUpdate : ValueTransferType::kTensorScatterUpdate;
2075 }
2076 } else if (py_value_type == TensorIndexType::Tuple || py_value_type == TensorIndexType::List) {
2077 (void)value_transfer_types.emplace_back(static_cast<int>(ValueTransferType::kHandleSequenceValue));
2078 (void)value_transfer_args.emplace_back(py::make_tuple(py::int_(set_item_by_one_tensor), index));
2079 if (CheckTypeIsInstance<TypeId>(index->data_type(), kIntTypes)) {
2080 np_index = TensorPy::SyncAsNumpy(*index);
2081 np_index = CastToInt(TensorIndex::np_module_.attr("expand_dims")(np_index, -1));
2082 tensor_update_type = ValueTransferType::kTensorScatterUpdate;
2083 } else if (index->data_type() == kNumberTypeBool) {
2084 return py::make_tuple(
2085 index, VectorToPyTuple<int64_t>(value_transfer_types), VectorToPyTuple<py::object>(value_transfer_args),
2086 py::make_tuple(static_cast<int>(ValueTransferType::kSetitemByBoolTensor)), py::make_tuple(py::none()));
2087 } else {
2088 MS_EXCEPTION(TypeError) << "The tensor index must be int or bool type, but got " << tensor_index;
2089 }
2090 }
2091 std::vector<int> tensor_update_types{static_cast<int>(tensor_update_type)};
2092 std::vector<py::object> tensor_update_args{py::none()};
2093 if (format_index_tensor) {
2094 (void)tensor_update_types.insert(tensor_update_types.begin(),
2095 static_cast<int>(ValueTransferType::kFormatIndexTensor));
2096 (void)tensor_update_args.insert(tensor_update_args.begin(), py::make_tuple(0, data_shape[0]));
2097 }
2098 return py::make_tuple(TensorPy::MakeTensor(TensorIndex::np_module_.attr("array")(np_index)),
2099 VectorToPyTuple<int64_t>(value_transfer_types),
2100 VectorToPyTuple<py::object>(value_transfer_args), VectorToPyTuple<int>(tensor_update_types),
2101 VectorToPyTuple<py::object>(tensor_update_args));
2102 }
2103
SetItemByTuple(const ShapeVector & data_shape,const TypePtr & data_type,const TensorIndex & py_index,const TensorIndexType & py_value_type)2104 py::object TensorIndex::SetItemByTuple(const ShapeVector &data_shape, const TypePtr &data_type,
2105 const TensorIndex &py_index, const TensorIndexType &py_value_type) {
2106 MS_LOG(DEBUG) << "In branch Set item by tuple, data_shape: " << data_shape << " tensor_indexes: " << py_index
2107 << "value: " << TensorIndex::py_value_handle_;
2108 if (!CheckTypeIsInstance<TensorIndexType>(py_value_type,
2109 {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
2110 TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple})) {
2111 MS_EXCEPTION(TypeError) << "Only support int, float, bool, Tensor, list, tuple as value, but got "
2112 << TensorIndex::py_value_handle_.get_type();
2113 }
2114
2115 std::tuple<int64_t, py::object, ShapeVector> value_transfer =
2116 GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, false);
2117 constexpr size_t value_transfer_types_index = 0;
2118 constexpr size_t value_transfer_args_index = 1;
2119 constexpr size_t value_transfer_shapes_index = 2;
2120 std::vector<int64_t> value_transfer_types = {std::get<value_transfer_types_index>(value_transfer)};
2121 std::vector<py::object> value_transfer_args = {std::get<value_transfer_args_index>(value_transfer)};
2122 ShapeVector value_transfer_shape = {std::get<value_transfer_shapes_index>(value_transfer)};
2123
2124 if (CheckTypeIsInstance<TensorIndexType>(
2125 py_value_type, {TensorIndexType::Boolean, TensorIndexType::Float, TensorIndexType::Integer})) {
2126 TensorIndex index = TensorIndex::UnpackTuple(py_index);
2127 std::vector<TensorIndex> index_list = index.ExpandToVector();
2128 return SetitemByTupleWithTensor(data_shape, index_list, value_transfer_shape, &value_transfer_types,
2129 &value_transfer_args);
2130 }
2131 std::vector<TensorIndex> index_list = py_index.ExpandToVector();
2132 return SetitemByTupleWithTensor(data_shape, index_list, value_transfer_shape, &value_transfer_types,
2133 &value_transfer_args);
2134 }
2135
SetItemBySlice(const ShapeVector & data_shape,const TypePtr & data_type,const TensorIndex & tensor_index,const TensorIndexType & py_value_type,const ValuePtr & data_value)2136 py::object TensorIndex::SetItemBySlice(const ShapeVector &data_shape, const TypePtr &data_type,
2137 const TensorIndex &tensor_index, const TensorIndexType &py_value_type,
2138 const ValuePtr &data_value) {
2139 MS_LOG(INFO) << "(View) In branch set item by slice, data_shape: " << data_shape
2140 << " tensor_indexes: " << tensor_index << "value: " << TensorIndex::py_value_handle_;
2141 if (!CheckTypeIsInstance<TensorIndexType>(py_value_type,
2142 {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean,
2143 TensorIndexType::Tensor, TensorIndexType::List, TensorIndexType::Tuple})) {
2144 MS_EXCEPTION(TypeError) << "Only support int, float, bool, Tensor, list, tuple as value, but got "
2145 << TensorIndex::py_value_handle_.get_type();
2146 }
2147 Slice slice_info = Slice(tensor_index.slice(), data_shape[0]);
2148 std::tuple<int64_t, py::object, ShapeVector> value_transfer =
2149 GetValueTransferType(py_value_type, set_item_by_non_tensor, data_type, slice_info.step() >= 0);
2150 std::vector<int64_t> value_transfer_types = {std::get<0>(value_transfer)};
2151 std::vector<py::object> value_transfer_args = {std::get<1>(value_transfer)};
2152 return SetitemBySliceWithTensor(data_shape, tensor_index, &value_transfer_types, &value_transfer_args, data_value,
2153 data_type);
2154 }
2155
// Entry point for tensor setitem preprocessing: resolves the target's shape/dtype/parameter
// flag, records the index/value handles in the class-level static state, tries the view fast
// paths (int index, slice, tuple-with-view, none/ellipsis), then dispatches on the index
// type via SetItemIndexByIndexType.
py::object TensorIndex::SetItemIndexInfo(const py::object &py_data, const py::object &py_index,
                                         const py::object &py_value, const py::bool_ &is_ascend) {
  if (!py::isinstance<Tensor>(py_data) && !IsStubTensor(py_data)) {
    MS_EXCEPTION(TypeError) << "First input of Tensor index must be tensor but got " << py_data;
  }
  ShapeVector data_shape;
  TypePtr data_type;
  bool is_parameter = false;
  // Stays null when view ops are disabled; the branches below use null to skip view paths.
  ValuePtr data_value;
  if (IsStubTensor(py_data)) { // PackTensor have not real Tensor.
    auto value = GetStubTensorValue(py_data);
    MS_EXCEPTION_IF_NULL(value);
    auto abs = value->ToAbstract();
    MS_EXCEPTION_IF_NULL(abs);
    data_shape = dyn_cast<abstract::Shape>(abs->BuildShape())->shape();
    data_type = abs->BuildType();
    MS_EXCEPTION_IF_NULL(data_type);
    // NOTE(review): this branch calls EnableView() without is_setitem=true, unlike the
    // tensor branch below which calls EnableView(true) — confirm whether stub tensors are
    // intentionally exempt from the setitem-grad view restriction.
    if (EnableView()) {
      data_value = value;
    }
  } else {
    TensorPtr data = py_data.cast<TensorPtr>();
    MS_EXCEPTION_IF_NULL(data);
    if (EnableView(true)) {
      data_value = data;
    }
    data_shape = data->shape();
    data_type = data->Dtype();
    is_parameter = data->is_parameter();
  }
  // Publish the current setitem request to the static state read by the helper routines.
  TensorIndex::py_value_handle_ = py_value;
  TensorIndex::np_module_ = py::module::import("numpy");
  TensorIndex::py_index_handle_ = py_index;
  TensorIndex::is_ascend_ = is_ascend;
  TensorIndex::index_op_type_ = IndexOpType::SetItem;
  const TensorIndexType value_type = IsStubTensor(py_value) ? TensorIndexType::Tensor : TensorIndex(py_value).type();
  bool valid = CheckTypeIsInstance<TensorIndexType>(
    value_type, {TensorIndexType::Integer, TensorIndexType::Float, TensorIndexType::Boolean, TensorIndexType::Tensor,
                 TensorIndexType::List, TensorIndexType::Tuple});
  if (!valid) {
    MS_EXCEPTION(TypeError) << "only support numbers, Tensor, tuple, list as value, but got "
                            << TensorIndex::py_value_handle_ << " with type "
                            << TensorIndex::py_value_handle_.get_type();
  }
  // View fast path for a plain int index (py::bool_ is a py::int_ subtype, hence excluded).
  if (py::isinstance<py::int_>(py_index) && !py::isinstance<py::bool_>(py_index) && data_value != nullptr) {
    return SetItemByNumberWithView(data_shape, data_type, is_parameter, TensorIndex(py_index), value_type, data_value);
  }
  if (py::isinstance<py::slice>(py_index)) {
    return TensorIndex::SetItemBySlice(data_shape, data_type, TensorIndex(py_index), value_type, data_value);
  }
  // None/ellipsis with a view-capable target: whole-tensor assignment handled on the Python side.
  if (data_value != nullptr && (py::isinstance<py::none>(py_index) || py::isinstance<py::ellipsis>(py_index))) {
    auto output = py::make_tuple(
      py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
      py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByEllipsis)), py::make_tuple(py::none()));
    return output;
  }
  std::vector<int64_t> data_transfer_types;
  std::vector<py::object> data_transfer_args;
  if (py::isinstance<py::tuple>(py_index) &&
      GetItemByTupleWithView(data_value, data_shape, py_index, &data_transfer_types, &data_transfer_args, data_type)) {
    MS_LOG(INFO) << "(View) In branch set item by tuple with view, data_shape: " << data_shape
                 << " tensor_indexes: " << py_index;
    return py::make_tuple(py::str("view"), py::tuple(), py::tuple(), VectorToPyTuple(data_transfer_types),
                          VectorToPyTuple(data_transfer_args));
  }
  MS_LOG(INFO) << "(Tensor) Set item data shape is: " << data_shape << ", index is: " << py_index
               << ", value is: " << py_value;
  TensorIndex index = TensorIndex(py_index);

  CheckSetItemIndex(index.type(), value_type);
  // List indexes are normalized against the first dimension before dispatch.
  if (index.IsList()) {
    if (data_shape.empty()) {
      MS_EXCEPTION(TypeError) << "Cannot iterate over a scalar tensor.";
    }
    index = TensorIndex::FormatList(index, data_shape[0]);
  }

  return SetItemIndexByIndexType(index, py_index, data_shape, data_type, value_type, is_parameter);
}
2235
// Dispatch setitem preprocessing on the index type. Each branch yields the
// (index, value transfer types, value transfer args, data transfer types, data transfer args)
// tuple consumed by the Python-side setitem implementation.
py::object TensorIndex::SetItemIndexByIndexType(const TensorIndex &index, const py::object &py_index,
                                                const ShapeVector &data_shape, const TypePtr &data_type,
                                                const TensorIndexType &value_type, bool is_parameter) {
  py::object output =
    py::make_tuple(py::none(), py::none(), py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kUnknown)),
                   py::make_tuple(py::none()));
  switch (index.type()) {
    case TensorIndexType::Integer: {
      output = SetItemByNumber(data_shape, data_type, is_parameter, index, value_type);
      break;
    }
    case TensorIndexType::Tensor: {
      output = SetItemByTensor(data_shape, is_parameter, index, value_type);
      break;
    }
    case TensorIndexType::Tuple: {
      output = SetItemByTuple(data_shape, data_type, index, value_type);
      break;
    }
    case TensorIndexType::Ellipsis:
    case TensorIndexType::None: {
      // Ellipsis/None address the whole tensor; defer to the dedicated ellipsis handler.
      output = py::make_tuple(
        py::none(), py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
        py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByEllipsis)), py::make_tuple(py::none()));
      break;
    }
    case TensorIndexType::Boolean: {
      output = py::make_tuple(
        py_index, py::make_tuple(static_cast<int>(ValueTransferType::kByPass)), py::make_tuple(py::none()),
        py::make_tuple(static_cast<int>(ValueTransferType::kSetItemByBool)), py::make_tuple(py::none()));
      break;
    }
    default: {
      // Fix: restore the missing space before "with type" so the message matches the
      // equivalent getitem error text.
      MS_EXCEPTION(TypeError)
        << "Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor, int, list and "
           "tuple as index, but got "
        << TensorIndex::py_index_handle_ << " with type " << TensorIndex::py_index_handle_.get_type();
    }
  }

  return output;
}
2278
2279 } // namespace mindspore::tensor
2280