1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ops/get_tuple_index_info.h"

#include <algorithm>
#include <bitset>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/structure_ops.h"
28
29 namespace mindspore {
30 namespace ops {
// Broadcasts two shapes against each other, aligning them from the trailing dimension.
//
// x_shape / y_shape: the two shapes to broadcast.
// Returns: the broadcast shape. Its rank is max(rank(x), rank(y)). Leading dims of the longer shape are
//   copied through unchanged. Each aligned pair resolves to the other value when one side is 1, or to
//   the common value when the two are equal.
// Throws: ValueError when an aligned pair differs and neither side is 1.
static ShapeVector BroadCastShape(const ShapeVector &x_shape, const ShapeVector &y_shape) {
  if (x_shape == y_shape) {
    return x_shape;
  }
  const size_t x_len = x_shape.size();
  const size_t y_len = y_shape.size();
  const size_t min_length = std::min(x_len, y_len);
  // The leading dims of the longer shape have no counterpart, so they are taken verbatim.
  // When both ranks are equal this prefix is empty.
  const ShapeVector &longer_shape = x_len >= y_len ? x_shape : y_shape;
  ShapeVector broadcast_shape(longer_shape.begin(), longer_shape.end() - static_cast<int64_t>(min_length));
  broadcast_shape.reserve(longer_shape.size());
  // The trailing `min_length` dims are aligned from the right and broadcast pairwise.
  for (size_t i = 0; i < min_length; i++) {
    const size_t x_shape_index = x_len - min_length + i;
    const size_t y_shape_index = y_len - min_length + i;
    const auto x_dim = x_shape[x_shape_index];
    const auto y_dim = y_shape[y_shape_index];
    if (x_dim == 1) {
      (void)broadcast_shape.emplace_back(y_dim);
    } else if (y_dim == 1 || x_dim == y_dim) {
      (void)broadcast_shape.emplace_back(x_dim);
    } else {
      MS_EXCEPTION(ValueError) << "For tensor getitem or setitem, x.shape and y.shape need to broadcast. "
                               << "The value of x.shape[" << std::to_string(x_shape_index) << "] or y.shape["
                               << std::to_string(y_shape_index) << "] must be 1 or -1 when they are not the same but "
                               << "got x.shape = " << x_shape << " and y.shape = " << y_shape;
    }
  }
  return broadcast_shape;
}
67
BroadCastShape(const std::vector<ShapeVector> & tensor_indexes_shapes)68 static ShapeVector BroadCastShape(const std::vector<ShapeVector> &tensor_indexes_shapes) {
69 if (tensor_indexes_shapes.empty()) {
70 return {};
71 }
72 return std::accumulate(tensor_indexes_shapes.begin(), tensor_indexes_shapes.end(), tensor_indexes_shapes[0],
73 [](const auto &output_shape, const auto &tensor_indexes_shape) {
74 return BroadCastShape(output_shape, tensor_indexes_shape);
75 });
76 }
77
GetFancyPosition(const std::vector<int64_t> & tuple_index_types,size_t fancy_position,size_t ellipse_occupy_dims,const string & tuple_index_info_type)78 static size_t GetFancyPosition(const std::vector<int64_t> &tuple_index_types, size_t fancy_position,
79 size_t ellipse_occupy_dims, const string &tuple_index_info_type) {
80 std::vector<int64_t> final_tuple_index_types;
81 for (size_t i = 0; i < tuple_index_types.size(); i++) {
82 if (tuple_index_types[i] == kMetaTypeEllipsis) {
83 auto ellipsis_slice = std::vector<int64_t>(ellipse_occupy_dims, kObjectTypeSlice);
84 (void)final_tuple_index_types.insert(final_tuple_index_types.end(), ellipsis_slice.begin(), ellipsis_slice.end());
85 } else {
86 (void)final_tuple_index_types.emplace_back(tuple_index_types[i]);
87 }
88 }
89 MS_LOG(DEBUG) << "final_tuple_index_types" << final_tuple_index_types;
90 std::bitset<kMaxTensorIndexDimNums> tensor_position_mask = 0;
91 for (size_t i = 0; i < final_tuple_index_types.size(); i++) {
92 if (final_tuple_index_types[i] == kObjectTypeTensorType) {
93 tensor_position_mask[i] = 1;
94 }
95 }
96 if (tuple_index_info_type != kSetitemByTuple && tuple_index_info_type != kSetitemByTupleWithTensor) {
97 int64_t new_fancy_position = -1;
98 if (tensor_position_mask == 0) {
99 return 0;
100 }
101 for (size_t i = 0; i < kMaxTensorIndexDimNums; i++) {
102 if (tensor_position_mask[i] == 0) {
103 continue;
104 }
105 bool first_tensor_found = new_fancy_position != -1;
106 if (first_tensor_found && tensor_position_mask[i - 1] == 0) {
107 return 0;
108 }
109 if (!first_tensor_found) {
110 new_fancy_position = static_cast<int64_t>(i);
111 }
112 }
113 return LongToSize(new_fancy_position);
114 }
115 return fancy_position;
116 }
117
GetSliceShape(const std::vector<int64_t> & tuple_index_types,const std::vector<ShapeVector> & tensor_shapes,bool * has_zero_tensor,size_t * slice_nums,std::vector<ShapeVector> * tensor_indices_shapes)118 static std::vector<ShapeVector> GetSliceShape(const std::vector<int64_t> &tuple_index_types,
119 const std::vector<ShapeVector> &tensor_shapes, bool *has_zero_tensor,
120 size_t *slice_nums, std::vector<ShapeVector> *tensor_indices_shapes) {
121 std::vector<ShapeVector> slice_shapes;
122 auto new_tuple_index_types = tuple_index_types;
123 for (size_t i = 0; i < tuple_index_types.size(); i++) {
124 if (new_tuple_index_types[i] == kMetaTypeEllipsis) {
125 (void)new_tuple_index_types.erase(new_tuple_index_types.begin() + i);
126 (void)new_tuple_index_types.emplace_back(kMetaTypeEllipsis);
127 break;
128 }
129 }
130 for (size_t i = 0; i < tensor_shapes.size(); i++) {
131 if (new_tuple_index_types[i] == kObjectTypeTensorType) {
132 if (!tensor_shapes[i].empty() && tensor_shapes[i][0] == 0) {
133 *has_zero_tensor = true;
134 }
135 (void)tensor_indices_shapes->emplace_back(tensor_shapes[i]);
136 } else {
137 (void)slice_shapes.emplace_back(tensor_shapes[i]);
138 *slice_nums = *slice_nums + 1;
139 }
140 }
141 return slice_shapes;
142 }
143
// Builds the reshaped form of the `slice_cnt`-th slice.
//
// The result has one entry per slice. Every entry is 1 except position `slice_cnt`, which keeps
// slice_shape[slice_cnt] (when in range). `broadcast_shape_len` further 1s are then inserted at
// `fancy_position`, which is assumed to be within [0, slice_shape.size()].
static ShapeVector ComputeSliceShape(const ShapeVector &slice_shape, size_t broadcast_shape_len, size_t slice_cnt,
                                     int64_t fancy_position) {
  const size_t slice_rank = slice_shape.size();
  ShapeVector result(slice_rank, 1);
  if (slice_cnt < slice_rank) {
    result[slice_cnt] = slice_shape[slice_cnt];
  }
  // Unit placeholders for the broadcast tensor-index dims.
  (void)result.insert(result.begin() + fancy_position, broadcast_shape_len, 1);
  return result;
}
154
// Computes the index metadata for tuple-index getitem/setitem from the shapes of the index operands.
//
// Inputs:
//   data_shape            - shape of the indexed data. Only used to extend final_shape when a zero-size index or
//                           an all-zero data shape is present.
//   tensor_shapes         - one shape per index position. Entries at tensor-typed positions are tensor indices;
//                           all others are treated as slices (see GetSliceShape).
//   tuple_index_types     - type id of each index entry (kObjectTypeTensorType, kObjectTypeSlice, kMetaTypeEllipsis,
//                           kTypeUnknown, ...).
//   tuple_index_info_type - selects the variant (e.g. kSetitemByTuple, kSetitemByTupleWithTensor,
//                           kPreSetitemByTuple).
// Outputs (written through the pointers):
//   broadcast_shape        - broadcast of all tensor-index shapes.
//   fancy_position         - in: caller's position. Out: the position where the broadcast dims go, clamped to
//                            the number of slices.
//   final_shape            - first-dim lengths of the slices, with broadcast_shape inserted at fancy_position.
//   index_tensor_new_shape - one 1 per slice, with broadcast_shape inserted at fancy_position.
// Returns: one shape per slice, computed by ComputeSliceShape. Returns an empty vector for kPreSetitemByTuple;
//   in that case only broadcast_shape and fancy_position have been written.
std::vector<ShapeVector> GetTupleIndexInfo::ConstGetTupleIndexInfo(
  const ShapeVector &data_shape, const std::vector<ShapeVector> &tensor_shapes,
  const std::vector<int64_t> &tuple_index_types, ShapeVector *broadcast_shape, ShapeVector *final_shape,
  ShapeVector *index_tensor_new_shape, size_t *fancy_position, const string &tuple_index_info_type) {
  // Get tuple index info: broadcast_shape
  // Count the index entries that occupy exactly one dim (everything except kTypeUnknown and the ellipsis).
  size_t not_ellipse_occupy_dims = 0;
  for (size_t i = 0; i < tuple_index_types.size(); i++) {
    if (tuple_index_types[i] != kTypeUnknown && tuple_index_types[i] != kMetaTypeEllipsis) {
      not_ellipse_occupy_dims += 1;
    }
  }
  std::vector<ShapeVector> tensor_indices_shapes;
  std::vector<ShapeVector> slice_shapes;
  size_t slice_nums = 0;
  bool has_zero_tensor = false;
  // GetSliceShape moves the ellipsis type to the back of the list, so any ellipsis-covered slices end up
  // at the tail of slice_shapes.
  slice_shapes = GetSliceShape(tuple_index_types, tensor_shapes, &has_zero_tensor, &slice_nums, &tensor_indices_shapes);
  MS_LOG(DEBUG) << "slice_shapes: " << slice_shapes;
  *broadcast_shape = BroadCastShape(tensor_indices_shapes);
  // For kSetitemByTupleWithTensor, prepend a single leading 1 when the broadcast shape has fewer than 2 dims.
  constexpr size_t min_broadcast_shape_size = 2;
  if (tuple_index_info_type == kSetitemByTupleWithTensor && broadcast_shape->size() < min_broadcast_shape_size) {
    (void)broadcast_shape->insert(broadcast_shape->begin(), 1);
  }
  MS_LOG(DEBUG) << "broadcast_shape:" << *broadcast_shape;
  // Get tuple index info: fancy_position
  // The positions not claimed by single-dim entries are taken to be covered by the ellipsis.
  size_t ellipse_occupy_dims = tensor_shapes.size() - not_ellipse_occupy_dims;
  *fancy_position = GetFancyPosition(tuple_index_types, *fancy_position, ellipse_occupy_dims, tuple_index_info_type);
  MS_LOG(DEBUG) << "fancy_position:" << *fancy_position;
  if (tuple_index_info_type == kPreSetitemByTuple) {
    return {};
  }
  // Get tuple index info: final_shape
  // pre_size_len = number of slice entries that appear before the ellipsis.
  size_t pre_size_len = 0;
  for (auto type : tuple_index_types) {
    if (type == kObjectTypeSlice) {
      pre_size_len += 1;
    } else if (type == kMetaTypeEllipsis) {
      break;
    }
  }
  // Collect the first dim of each slice shape. The trailing ellipsis-covered slices are gathered separately
  // and re-inserted at pre_size_len, so slice_len follows the original index order.
  ShapeVector slice_len;
  size_t not_ellipse_slice_cnt = slice_shapes.size() - ellipse_occupy_dims;
  std::transform(slice_shapes.begin(), slice_shapes.begin() + not_ellipse_slice_cnt, std::back_inserter(slice_len),
                 [](const ShapeVector &slice_shape) {
                   if (slice_shape.empty()) {
                     MS_LOG(EXCEPTION) << "Slice tensor can not be empty!";
                   }
                   return slice_shape[0];
                 });
  ShapeVector ellipse_slice;
  std::transform(slice_shapes.begin() + not_ellipse_slice_cnt, slice_shapes.end(), std::back_inserter(ellipse_slice),
                 [](const ShapeVector &slice_shape) {
                   if (slice_shape.empty()) {
                     MS_LOG(EXCEPTION) << "Slice tensor can not be empty!";
                   }
                   return slice_shape[0];
                 });
  (void)slice_len.insert(slice_len.begin() + pre_size_len, ellipse_slice.begin(), ellipse_slice.end());
  *fancy_position = std::min(*fancy_position, slice_nums);
  *final_shape = ShapeVector(slice_len.begin(), slice_len.begin() + slice_nums);
  (void)final_shape->insert(final_shape->begin() + *fancy_position, broadcast_shape->begin(), broadcast_shape->end());
  // Note: std::all_of is also true for an empty data_shape.
  has_zero_tensor =
    has_zero_tensor || std::all_of(data_shape.begin(), data_shape.end(), [](int64_t dim) { return dim == 0; });
  // In the zero-size case, append the data dims that lie beyond the indexed positions.
  if (has_zero_tensor && tensor_shapes.size() < data_shape.size()) {
    (void)final_shape->insert(final_shape->end(), data_shape.begin() + tensor_shapes.size(), data_shape.end());
  }
  MS_LOG(DEBUG) << "final_shape:" << *final_shape;
  // Get tuple index info: index_tensor_new_shape
  *index_tensor_new_shape = ShapeVector(slice_nums, 1);
  *fancy_position = std::min(*fancy_position, index_tensor_new_shape->size());
  (void)index_tensor_new_shape->insert(index_tensor_new_shape->begin() + *fancy_position, broadcast_shape->begin(),
                                       broadcast_shape->end());
  MS_LOG(DEBUG) << "index_tensor_new_shape:" << *index_tensor_new_shape;
  // Get tuple index info: new_slice_shapes
  std::vector<ShapeVector> new_slice_shapes;
  for (size_t i = 0; i < slice_nums; i++) {
    (void)new_slice_shapes.emplace_back(ComputeSliceShape(slice_len, broadcast_shape->size(), i, *fancy_position));
  }
  // Move the ellipsis-covered slice shapes [pre_size_len, pre_size_len + ellipse_occupy_dims) back to the end.
  // This matches the order in which GetSliceShape emitted the slices.
  std::vector<ShapeVector> ellipse_slice_shape_vector(new_slice_shapes.begin() + pre_size_len,
                                                      new_slice_shapes.begin() + pre_size_len + ellipse_occupy_dims);
  (void)new_slice_shapes.erase(new_slice_shapes.begin() + pre_size_len,
                               new_slice_shapes.begin() + pre_size_len + ellipse_occupy_dims);
  (void)new_slice_shapes.insert(new_slice_shapes.end(), ellipse_slice_shape_vector.begin(),
                                ellipse_slice_shape_vector.end());
  MS_LOG(DEBUG) << "new_slice_shapes:" << new_slice_shapes;
  return new_slice_shapes;
}
241
VectorToAbstract(std::vector<int64_t> nums)242 static AbstractBasePtr VectorToAbstract(std::vector<int64_t> nums) {
243 abstract::AbstractBasePtrList elems;
244 std::transform(nums.begin(), nums.end(), std::back_inserter(elems),
245 [](int64_t num) { return std::make_shared<abstract::AbstractScalar>(num); });
246 return std::make_shared<abstract::AbstractTuple>(elems);
247 }
248
// Operator boilerplate for GetTupleIndexInfo, with BaseOperator passed as the base.
// The macro is defined outside this file (presumably reached via mindapi/src/helper.h; confirm).
MIND_API_OPERATOR_IMPL(GetTupleIndexInfo, BaseOperator);
250
GetTupleIndexInfoInferInner(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)251 AbstractBasePtr GetTupleIndexInfoInferInner(const PrimitivePtr &primitive,
252 const std::vector<AbstractBasePtr> &input_args) {
253 MS_EXCEPTION_IF_NULL(primitive);
254 auto data_shape = input_args[kIndex0]->GetShape()->GetShapeVector();
255 const AbstractBasePtr &fancy_position_abs = input_args[kIndex1];
256 auto tuple_index_types = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrTupleIndexTypes));
257 string tuple_index_info_type;
258 if (primitive->HasAttr(kAttrTupleIndexInfoType)) {
259 tuple_index_info_type = GetValue<string>(primitive->GetAttr(kAttrTupleIndexInfoType));
260 }
261 const size_t output_size = 12;
262 if (fancy_position_abs->GetValue()->isa<ValueAny>() ||
263 std::any_of(input_args.begin() + kIndex0, input_args.end(),
264 [](const AbstractBasePtr &shape_abs) { return shape_abs->GetShape()->IsDynamic(); })) {
265 auto scalar_abs_any = std::make_shared<abstract::AbstractScalar>(kValueAny, kInt64);
266 auto tuple_abs = std::make_shared<abstract::AbstractTuple>(std::vector<abstract::AbstractBasePtr>{scalar_abs_any});
267 AbstractBasePtrList output_abs_list(output_size - 1, tuple_abs->BroadenToDynamicLenSequence());
268 output_abs_list.insert(output_abs_list.begin(), scalar_abs_any);
269 return std::make_shared<abstract::AbstractTuple>(output_abs_list);
270 }
271 if (data_shape.size() < 1 || data_shape.size() > kMaxTensorIndexDimNums) {
272 MS_EXCEPTION(ValueError) << "The input data's dim must in the range of [1, 8], but got " << data_shape.size();
273 }
274 std::vector<ShapeVector> tensor_indices_shapes;
275 ShapeVector slice_shapes;
276 size_t valid_tensor_nums = 0;
277 int64_t expand_dims = GetValue<int64_t>(primitive->GetAttr(kAttrExpandDimsCnt));
278 for (size_t i = 0; i < tuple_index_types.size(); i++) {
279 if (tuple_index_types[i] == kMetaTypeEllipsis) {
280 valid_tensor_nums = data_shape.size() + LongToSize(expand_dims);
281 break;
282 } else if (tuple_index_types[i] != kTypeUnknown) {
283 valid_tensor_nums += 1;
284 }
285 }
286 for (size_t i = 0; i < valid_tensor_nums; i++) {
287 auto input_shape = input_args[i + kIndex2]->GetShape()->GetShapeVector();
288 (void)tensor_indices_shapes.emplace_back(input_shape);
289 }
290 MS_LOG(DEBUG) << "valid_tensor_nums:" << valid_tensor_nums;
291 ShapeVector broadcast_shape;
292 ShapeVector final_shape;
293 ShapeVector index_tensor_new_shape;
294 auto fancy_position_opt = GetScalarValue<int64_t>(fancy_position_abs->GetValue());
295 if (!fancy_position_opt.has_value()) {
296 MS_EXCEPTION(ValueError) << "The value of fancy_position should not be none.";
297 }
298 auto fancy_position = fancy_position_opt.value();
299 auto new_slice_shapes = GetTupleIndexInfo::ConstGetTupleIndexInfo(
300 data_shape, tensor_indices_shapes, tuple_index_types, &broadcast_shape, &final_shape, &index_tensor_new_shape,
301 reinterpret_cast<size_t *>(&fancy_position), tuple_index_info_type);
302
303 AbstractBasePtrList abs_list{std::make_shared<abstract::AbstractScalar>(fancy_position),
304 VectorToAbstract(broadcast_shape), VectorToAbstract(index_tensor_new_shape),
305 VectorToAbstract(final_shape)};
306 for (auto new_slice_shape : new_slice_shapes) {
307 (void)abs_list.emplace_back(VectorToAbstract(new_slice_shape));
308 }
309
310 for (size_t i = 0; i < kMaxTensorIndexDimNums - new_slice_shapes.size(); i++) {
311 ShapeVector shape(1);
312 (void)abs_list.emplace_back(VectorToAbstract(shape));
313 }
314 auto output_abs = std::make_shared<abstract::AbstractTuple>(abs_list);
315 return output_abs;
316 }
317
318 class MIND_API GetTupleIndexInfoInfer : public abstract::OpInferBase {
319 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const320 BaseShapePtr InferShape(const PrimitivePtr &primitive,
321 const std::vector<AbstractBasePtr> &input_args) const override {
322 return GetTupleIndexInfoInferInner(primitive, input_args)->GetShape();
323 }
324
InferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args) const325 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
326 return GetTupleIndexInfoInferInner(prim, input_args)->GetType();
327 }
328
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const329 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
330 const std::vector<AbstractBasePtr> &input_args) const override {
331 return GetTupleIndexInfoInferInner(primitive, input_args);
332 }
333 };
334
335 REGISTER_PRIMITIVE_OP_INFER_IMPL(GetTupleIndexInfo, prim::kPrimGetTupleIndexInfo, GetTupleIndexInfoInfer, false);
336 } // namespace ops
337 } // namespace mindspore
338