/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/get_tuple_index_info.h"

#include <algorithm>
#include <memory>
#include <bitset>
#include <numeric>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/structure_ops.h"

namespace mindspore {
namespace ops {
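// Broadcasts two shapes pairwise: trailing dimensions must be equal or one of them must be 1;
// the leading dimensions of the longer shape are prepended unchanged.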
static ShapeVector BroadCastShape(const ShapeVector &x_shape, const ShapeVector &y_shape) {
  if (x_shape == y_shape) {
    return x_shape;
  }
  const size_t x_len = x_shape.size();
  const size_t y_len = y_shape.size();
  const size_t min_length = std::min(x_len, y_len);
  ShapeVector broadcast_shape_back;
  for (size_t i = 0; i < min_length; i++) {
    size_t x_shape_index = x_len - min_length + i;
    size_t y_shape_index = y_len - min_length + i;
    if (x_shape[x_shape_index] == 1) {
      (void)broadcast_shape_back.emplace_back(y_shape[y_shape_index]);
    } else if (y_shape[y_shape_index] == 1 || x_shape[x_shape_index] == y_shape[y_shape_index]) {
      (void)broadcast_shape_back.emplace_back(x_shape[x_shape_index]);
    } else {
      MS_EXCEPTION(ValueError) << "For tensor getitem or setitem, x.shape and y.shape need to broadcast. "
                               << "The value of x.shape[" << std::to_string(x_shape_index) << "] or y.shape["
                               << std::to_string(y_shape_index) << "] must be 1 or -1 when they are not the same but "
                               << "got x.shape = " << x_shape << " and y.shape = " << y_shape;
    }
  }
  ShapeVector broadcast_shape_front;
  if (min_length == x_len) {
    (void)broadcast_shape_front.insert(
      broadcast_shape_front.end(), y_shape.begin(),
      y_shape.begin() + static_cast<int64_t>(y_len) - static_cast<int64_t>(min_length));
  } else {
    (void)broadcast_shape_front.insert(
      broadcast_shape_front.end(), x_shape.begin(),
      x_shape.begin() + static_cast<int64_t>(x_len) - static_cast<int64_t>(min_length));
  }
  (void)broadcast_shape_front.insert(broadcast_shape_front.end(), broadcast_shape_back.begin(),
                                     broadcast_shape_back.end());
  return broadcast_shape_front;
}

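// Folds a list of tensor-index shapes into one broadcast shape by applying the pairwise
// BroadCastShape above from left to right.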
static ShapeVector BroadCastShape(const std::vector<ShapeVector> &tensor_indexes_shapes) {
  if (tensor_indexes_shapes.empty()) {
    return {};
  }
  return std::accumulate(tensor_indexes_shapes.begin(), tensor_indexes_shapes.end(), tensor_indexes_shapes[0],
                         [](const auto &output_shape, const auto &tensor_indexes_shape) {
                           return BroadCastShape(output_shape, tensor_indexes_shape);
                         });
}

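// Determines where the broadcast shape of the tensor indices is inserted into the output shape.
// For setitem-by-tuple ops the given fancy_position is returned unchanged; otherwise contiguous
// tensor indices keep the position of the first tensor index, and non-contiguous ones fall back
// to position 0.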
static size_t GetFancyPosition(const std::vector<int64_t> &tuple_index_types, size_t fancy_position,
                               size_t ellipse_occupy_dims, const string &tuple_index_info_type) {
  std::vector<int64_t> final_tuple_index_types;
  for (size_t i = 0; i < tuple_index_types.size(); i++) {
    if (tuple_index_types[i] == kMetaTypeEllipsis) {
      auto ellipsis_slice = std::vector<int64_t>(ellipse_occupy_dims, kObjectTypeSlice);
      (void)final_tuple_index_types.insert(final_tuple_index_types.end(), ellipsis_slice.begin(), ellipsis_slice.end());
    } else {
      (void)final_tuple_index_types.emplace_back(tuple_index_types[i]);
    }
  }
  MS_LOG(DEBUG) << "final_tuple_index_types: " << final_tuple_index_types;
  std::bitset<kMaxTensorIndexDimNums> tensor_position_mask = 0;
  for (size_t i = 0; i < final_tuple_index_types.size(); i++) {
    if (final_tuple_index_types[i] == kObjectTypeTensorType) {
      tensor_position_mask[i] = 1;
    }
  }
  if (tuple_index_info_type != kSetitemByTuple && tuple_index_info_type != kSetitemByTupleWithTensor) {
    int64_t new_fancy_position = -1;
    if (tensor_position_mask == 0) {
      return 0;
    }
    for (size_t i = 0; i < kMaxTensorIndexDimNums; i++) {
      if (tensor_position_mask[i] == 0) {
        continue;
      }
      bool first_tensor_found = new_fancy_position != -1;
      if (first_tensor_found && tensor_position_mask[i - 1] == 0) {
        return 0;
      }
      if (!first_tensor_found) {
        new_fancy_position = static_cast<int64_t>(i);
      }
    }
    return LongToSize(new_fancy_position);
  }
  return fancy_position;
}

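// Splits the per-dimension shapes into tensor-index shapes and slice shapes according to the
// tuple index types (with the ellipsis moved to the end), counts the slice dimensions, and flags
// tensor indices whose first dimension is zero.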
static std::vector<ShapeVector> GetSliceShape(const std::vector<int64_t> &tuple_index_types,
                                              const std::vector<ShapeVector> &tensor_shapes, bool *has_zero_tensor,
                                              size_t *slice_nums, std::vector<ShapeVector> *tensor_indices_shapes) {
  std::vector<ShapeVector> slice_shapes;
  auto new_tuple_index_types = tuple_index_types;
  for (size_t i = 0; i < tuple_index_types.size(); i++) {
    if (new_tuple_index_types[i] == kMetaTypeEllipsis) {
      (void)new_tuple_index_types.erase(new_tuple_index_types.begin() + i);
      (void)new_tuple_index_types.emplace_back(kMetaTypeEllipsis);
      break;
    }
  }
  for (size_t i = 0; i < tensor_shapes.size(); i++) {
    if (new_tuple_index_types[i] == kObjectTypeTensorType) {
      if (!tensor_shapes[i].empty() && tensor_shapes[i][0] == 0) {
        *has_zero_tensor = true;
      }
      (void)tensor_indices_shapes->emplace_back(tensor_shapes[i]);
    } else {
      (void)slice_shapes.emplace_back(tensor_shapes[i]);
      *slice_nums = *slice_nums + 1;
    }
  }
  return slice_shapes;
}

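// Builds the reshaped shape of one slice dimension: every slice dimension is 1 except the
// slice_cnt-th one, with placeholder 1s for the broadcast shape inserted at fancy_position.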
static ShapeVector ComputeSliceShape(const ShapeVector &slice_shape, size_t broadcast_shape_len, size_t slice_cnt,
                                     int64_t fancy_position) {
  ShapeVector shape(slice_shape.size(), 1);
  if (slice_cnt < shape.size()) {
    shape[slice_cnt] = slice_shape[slice_cnt];
  }
  ShapeVector temp_shape(broadcast_shape_len, 1);
  (void)shape.insert(shape.begin() + fancy_position, temp_shape.begin(), temp_shape.end());
  return shape;
}

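// Computes all shape information needed to evaluate a tuple index: the broadcast shape of the
// tensor indices, the fancy position, the final output shape, the reshaped index-tensor shape,
// and the reshaped slice shapes (the return value).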
std::vector<ShapeVector> GetTupleIndexInfo::ConstGetTupleIndexInfo(
  const ShapeVector &data_shape, const std::vector<ShapeVector> &tensor_shapes,
  const std::vector<int64_t> &tuple_index_types, ShapeVector *broadcast_shape, ShapeVector *final_shape,
  ShapeVector *index_tensor_new_shape, size_t *fancy_position, const string &tuple_index_info_type) {
  // Get tuple index info: broadcast_shape
  size_t not_ellipse_occupy_dims = 0;
  for (size_t i = 0; i < tuple_index_types.size(); i++) {
    if (tuple_index_types[i] != kTypeUnknown && tuple_index_types[i] != kMetaTypeEllipsis) {
      not_ellipse_occupy_dims += 1;
    }
  }
  std::vector<ShapeVector> tensor_indices_shapes;
  std::vector<ShapeVector> slice_shapes;
  size_t slice_nums = 0;
  bool has_zero_tensor = false;
  slice_shapes = GetSliceShape(tuple_index_types, tensor_shapes, &has_zero_tensor, &slice_nums, &tensor_indices_shapes);
  MS_LOG(DEBUG) << "slice_shapes: " << slice_shapes;
  *broadcast_shape = BroadCastShape(tensor_indices_shapes);
  constexpr size_t min_broadcast_shape_size = 2;
  if (tuple_index_info_type == kSetitemByTupleWithTensor && broadcast_shape->size() < min_broadcast_shape_size) {
    (void)broadcast_shape->insert(broadcast_shape->begin(), 1);
  }
  MS_LOG(DEBUG) << "broadcast_shape:" << *broadcast_shape;
  // Get tuple index info: fancy_position
  size_t ellipse_occupy_dims = tensor_shapes.size() - not_ellipse_occupy_dims;
  *fancy_position = GetFancyPosition(tuple_index_types, *fancy_position, ellipse_occupy_dims, tuple_index_info_type);
  MS_LOG(DEBUG) << "fancy_position:" << *fancy_position;
  if (tuple_index_info_type == kPreSetitemByTuple) {
    return {};
  }
  // Get tuple index info: final_shape
  size_t pre_size_len = 0;
  for (auto type : tuple_index_types) {
    if (type == kObjectTypeSlice) {
      pre_size_len += 1;
    } else if (type == kMetaTypeEllipsis) {
      break;
    }
  }
  ShapeVector slice_len;
  size_t not_ellipse_slice_cnt = slice_shapes.size() - ellipse_occupy_dims;
  std::transform(slice_shapes.begin(), slice_shapes.begin() + not_ellipse_slice_cnt, std::back_inserter(slice_len),
                 [](const ShapeVector &slice_shape) {
                   if (slice_shape.empty()) {
                     MS_LOG(EXCEPTION) << "Slice tensor can not be empty!";
                   }
                   return slice_shape[0];
                 });
  ShapeVector ellipse_slice;
  std::transform(slice_shapes.begin() + not_ellipse_slice_cnt, slice_shapes.end(), std::back_inserter(ellipse_slice),
                 [](const ShapeVector &slice_shape) {
                   if (slice_shape.empty()) {
                     MS_LOG(EXCEPTION) << "Slice tensor can not be empty!";
                   }
                   return slice_shape[0];
                 });
  (void)slice_len.insert(slice_len.begin() + pre_size_len, ellipse_slice.begin(), ellipse_slice.end());
  *fancy_position = std::min(*fancy_position, slice_nums);
  *final_shape = ShapeVector(slice_len.begin(), slice_len.begin() + slice_nums);
  (void)final_shape->insert(final_shape->begin() + *fancy_position, broadcast_shape->begin(), broadcast_shape->end());
  has_zero_tensor =
    has_zero_tensor || std::all_of(data_shape.begin(), data_shape.end(), [](int64_t dim) { return dim == 0; });
  if (has_zero_tensor && tensor_shapes.size() < data_shape.size()) {
    (void)final_shape->insert(final_shape->end(), data_shape.begin() + tensor_shapes.size(), data_shape.end());
  }
  MS_LOG(DEBUG) << "final_shape:" << *final_shape;
  // Get tuple index info: index_tensor_new_shape
  *index_tensor_new_shape = ShapeVector(slice_nums, 1);
  *fancy_position = std::min(*fancy_position, index_tensor_new_shape->size());
  (void)index_tensor_new_shape->insert(index_tensor_new_shape->begin() + *fancy_position, broadcast_shape->begin(),
                                       broadcast_shape->end());
  MS_LOG(DEBUG) << "index_tensor_new_shape:" << *index_tensor_new_shape;
  // Get tuple index info: new_slice_shapes
  std::vector<ShapeVector> new_slice_shapes;
  for (size_t i = 0; i < slice_nums; i++) {
    (void)new_slice_shapes.emplace_back(ComputeSliceShape(slice_len, broadcast_shape->size(), i, *fancy_position));
  }
  std::vector<ShapeVector> ellipse_slice_shape_vector(new_slice_shapes.begin() + pre_size_len,
                                                      new_slice_shapes.begin() + pre_size_len + ellipse_occupy_dims);
  (void)new_slice_shapes.erase(new_slice_shapes.begin() + pre_size_len,
                               new_slice_shapes.begin() + pre_size_len + ellipse_occupy_dims);
  (void)new_slice_shapes.insert(new_slice_shapes.end(), ellipse_slice_shape_vector.begin(),
                                ellipse_slice_shape_vector.end());
  MS_LOG(DEBUG) << "new_slice_shapes:" << new_slice_shapes;
  return new_slice_shapes;
}

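// Wraps a vector of int64 values into an AbstractTuple of AbstractScalar elements.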
static AbstractBasePtr VectorToAbstract(std::vector<int64_t> nums) {
  abstract::AbstractBasePtrList elems;
  std::transform(nums.begin(), nums.end(), std::back_inserter(elems),
                 [](int64_t num) { return std::make_shared<abstract::AbstractScalar>(num); });
  return std::make_shared<abstract::AbstractTuple>(elems);
}

MIND_API_OPERATOR_IMPL(GetTupleIndexInfo, BaseOperator);

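// Infer implementation of GetTupleIndexInfo. The output is a 12-element tuple:
// (fancy_position, broadcast_shape, index_tensor_new_shape, final_shape) followed by eight slice
// shapes. When the fancy position or any input shape is still dynamic, placeholder abstracts
// (an any-valued scalar and dynamic-length tuples) are returned instead.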
AbstractBasePtr GetTupleIndexInfoInferInner(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto data_shape = input_args[kIndex0]->GetShape()->GetShapeVector();
  const AbstractBasePtr &fancy_position_abs = input_args[kIndex1];
  auto tuple_index_types = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrTupleIndexTypes));
  string tuple_index_info_type;
  if (primitive->HasAttr(kAttrTupleIndexInfoType)) {
    tuple_index_info_type = GetValue<string>(primitive->GetAttr(kAttrTupleIndexInfoType));
  }
  const size_t output_size = 12;
  if (fancy_position_abs->GetValue()->isa<ValueAny>() ||
      std::any_of(input_args.begin() + kIndex0, input_args.end(),
                  [](const AbstractBasePtr &shape_abs) { return shape_abs->GetShape()->IsDynamic(); })) {
    auto scalar_abs_any = std::make_shared<abstract::AbstractScalar>(kValueAny, kInt64);
    auto tuple_abs = std::make_shared<abstract::AbstractTuple>(std::vector<abstract::AbstractBasePtr>{scalar_abs_any});
    AbstractBasePtrList output_abs_list(output_size - 1, tuple_abs->BroadenToDynamicLenSequence());
    output_abs_list.insert(output_abs_list.begin(), scalar_abs_any);
    return std::make_shared<abstract::AbstractTuple>(output_abs_list);
  }
  if (data_shape.size() < 1 || data_shape.size() > kMaxTensorIndexDimNums) {
    MS_EXCEPTION(ValueError) << "The input data's dim must be in the range of [1, 8], but got " << data_shape.size();
  }
  std::vector<ShapeVector> tensor_indices_shapes;
  ShapeVector slice_shapes;
  size_t valid_tensor_nums = 0;
  int64_t expand_dims = GetValue<int64_t>(primitive->GetAttr(kAttrExpandDimsCnt));
  for (size_t i = 0; i < tuple_index_types.size(); i++) {
    if (tuple_index_types[i] == kMetaTypeEllipsis) {
      valid_tensor_nums = data_shape.size() + LongToSize(expand_dims);
      break;
    } else if (tuple_index_types[i] != kTypeUnknown) {
      valid_tensor_nums += 1;
    }
  }
  for (size_t i = 0; i < valid_tensor_nums; i++) {
    auto input_shape = input_args[i + kIndex2]->GetShape()->GetShapeVector();
    (void)tensor_indices_shapes.emplace_back(input_shape);
  }
  MS_LOG(DEBUG) << "valid_tensor_nums:" << valid_tensor_nums;
  ShapeVector broadcast_shape;
  ShapeVector final_shape;
  ShapeVector index_tensor_new_shape;
  auto fancy_position_opt = GetScalarValue<int64_t>(fancy_position_abs->GetValue());
  if (!fancy_position_opt.has_value()) {
    MS_EXCEPTION(ValueError) << "The value of fancy_position should not be none.";
  }
  auto fancy_position = fancy_position_opt.value();
  auto new_slice_shapes = GetTupleIndexInfo::ConstGetTupleIndexInfo(
    data_shape, tensor_indices_shapes, tuple_index_types, &broadcast_shape, &final_shape, &index_tensor_new_shape,
    reinterpret_cast<size_t *>(&fancy_position), tuple_index_info_type);

  AbstractBasePtrList abs_list{std::make_shared<abstract::AbstractScalar>(fancy_position),
                               VectorToAbstract(broadcast_shape), VectorToAbstract(index_tensor_new_shape),
                               VectorToAbstract(final_shape)};
  for (auto new_slice_shape : new_slice_shapes) {
    (void)abs_list.emplace_back(VectorToAbstract(new_slice_shape));
  }

  for (size_t i = 0; i < kMaxTensorIndexDimNums - new_slice_shapes.size(); i++) {
    ShapeVector shape(1);
    (void)abs_list.emplace_back(VectorToAbstract(shape));
  }
  auto output_abs = std::make_shared<abstract::AbstractTuple>(abs_list);
  return output_abs;
}

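// Infer entry points (shape, type, and full abstract) registered for the GetTupleIndexInfo primitive.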
class MIND_API GetTupleIndexInfoInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return GetTupleIndexInfoInferInner(primitive, input_args)->GetShape();
  }

  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    return GetTupleIndexInfoInferInner(prim, input_args)->GetType();
  }

  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return GetTupleIndexInfoInferInner(primitive, input_args);
  }
};

REGISTER_PRIMITIVE_OP_INFER_IMPL(GetTupleIndexInfo, prim::kPrimGetTupleIndexInfo, GetTupleIndexInfoInfer, false);
}  // namespace ops
}  // namespace mindspore