• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/ops/primitive_infer_map.h"
20 #include <string>
21 #include <vector>
22 #include <set>
23 #include <algorithm>
24 #include <cstdint>
25 #include <iterator>
26 
27 #include "abstract/utils.h"
28 #include "ops/sparse_ops.h"
29 #include "ops/random_ops.h"
30 #include "ops/conv_pool_ops.h"
31 #include "ops/other_ops.h"
32 #include "ops/nn_ops.h"
33 #include "ops/math_ops.h"
34 #include "ops/image_ops.h"
35 #include "ops/array_ops.h"
36 #include "ops/framework_ops.h"
37 #include "ops/ops_frontend_func_impl.h"
38 #include "ops/op_def.h"
39 #include "ops/shape_calc.h"
40 #include "ops/op_utils.h"
41 #include "include/common/utils/utils.h"
42 #include "utils/ms_context.h"
43 
44 namespace mindspore {
45 namespace abstract {
GetDependValueSize(const ValuePtr & value)46 int64_t GetDependValueSize(const ValuePtr &value) {
47   if (value->isa<Int64Imm>()) {
48     return GetValue<int64_t>(value);
49   }
50   if (!value->isa<ValueTuple>()) {
51     MS_LOG(EXCEPTION) << "the element of attr[dyn_input_size] should be all int64 of ValueTuple but got"
52                       << value->ToString() << ", type :" << value->type_name();
53   }
54   int64_t size = 0;
55   auto value_tuple = value->cast_ptr<ValueTuple>();
56   MS_EXCEPTION_IF_NULL(value_tuple);
57   for (size_t i = 0; i < value_tuple->size(); ++i) {
58     size += GetDependValueSize((*value_tuple)[i]);
59   }
60   return size;
61 }
62 
CheckScalarValid(const AbstractBasePtr & input_abstract)63 bool CheckScalarValid(const AbstractBasePtr &input_abstract) {
64   // Now, only scalar with int/float/uint will be used as the output of operator, so only add them to list.
65   if (input_abstract->isa<abstract::AbstractScalar>()) {
66     auto scalar_id = NormalizeTypeId(input_abstract->BuildType()->type_id());
67     return (scalar_id == kNumberTypeBool || scalar_id == kNumberTypeInt || scalar_id == kNumberTypeFloat ||
68             scalar_id == kNumberTypeUInt);
69   }
70   return false;
71 }
72 
CheckNeedAddToDependList(const AbstractBasePtr & input_abstract)73 bool CheckNeedAddToDependList(const AbstractBasePtr &input_abstract) {
74   auto is_tensor = input_abstract->isa<abstract::AbstractTensor>();
75   bool is_integer = false;
76   bool is_tuple_scalar_or_tensor = false;
77   is_integer = CheckScalarValid(input_abstract);
78   if (input_abstract->isa<abstract::AbstractTuple>()) {
79     auto tuple_abs = input_abstract->cast_ptr<abstract::AbstractTuple>();
80     auto elements = tuple_abs->elements();
81     is_tuple_scalar_or_tensor = std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
82       return (CheckScalarValid(element)) || element->isa<abstract::AbstractTensor>();
83     });
84   }
85   return is_tensor || is_integer || is_tuple_scalar_or_tensor;
86 }
87 
RectifyDependListFromDynamicInputAttr(const CNodePtr & cnode,const PrimitivePtr & primitive,const std::set<int64_t> & ori_depend_list)88 std::set<int64_t> RectifyDependListFromDynamicInputAttr(const CNodePtr &cnode, const PrimitivePtr &primitive,
89                                                         const std::set<int64_t> &ori_depend_list) {
90   std::set<int64_t> rec_depend_list = {};
91   constexpr auto all_tensor_inputs = -1;
92   if (ori_depend_list.size() == 1 && *(ori_depend_list.cbegin()) == all_tensor_inputs) {
93     for (size_t i = 1; i < cnode->size(); ++i) {
94       const auto &input = cnode->inputs()[i];
95       const auto &input_abstract = input->abstract();
96       if (input_abstract != nullptr) {
97         auto need_add_to_depend_list = CheckNeedAddToDependList(input_abstract);
98         if (need_add_to_depend_list) {
99           (void)rec_depend_list.emplace(SizeToLong(i - 1));
100         }
101       }
102     }
103     return rec_depend_list;
104   }
105 
106   auto attr = primitive->GetAttr(kAttrDynInputSizes);
107   if (attr == nullptr) {
108     return ori_depend_list;
109   }
110 
111   // mapping from input prototype index to corresponding start index of real input
112   std::vector<int64_t> dyn_input_sizes = GetValue<std::vector<int64_t>>(attr);
113   std::vector<int64_t> proto2real;
114   int64_t count = 0;
115   std::for_each(dyn_input_sizes.begin(), dyn_input_sizes.end(), [&count, &proto2real](int64_t dyn_size) {
116     proto2real.push_back(count);
117     count += dyn_size < 0 ? 1 : dyn_size;
118   });
119 
120   std::for_each(ori_depend_list.begin(), ori_depend_list.end(),
121                 [&proto2real, &dyn_input_sizes, &primitive, &rec_depend_list](int64_t proto_idx) {
122                   if (proto_idx >= static_cast<int64_t>(dyn_input_sizes.size())) {
123                     MS_LOG(EXCEPTION) << "The value depend index " << proto_idx << " of primitive " << primitive->name()
124                                       << " is out of range [0, " << dyn_input_sizes.size() << ").";
125                   }
126                   // value depend input is a normal input
127                   if (dyn_input_sizes[proto_idx] < 0) {
128                     rec_depend_list.insert(proto2real[proto_idx]);
129                   }
130                   // value depend input is is a dynamic input
131                   for (int64_t i = 0; i < dyn_input_sizes[proto_idx]; ++i) {
132                     rec_depend_list.insert(proto2real[proto_idx] + i);
133                   }
134                 });
135 
136   return rec_depend_list;
137 }
138 
GetValueDependArgIndices(const CNodePtr & cnode,bool is_proto)139 std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode, bool is_proto) {
140   MS_EXCEPTION_IF_NULL(cnode);
141   if (cnode->inputs().empty()) {
142     MS_LOG(EXCEPTION) << "Invalid inputs";
143   }
144   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
145   if (primitive == nullptr) {
146     return {};
147   }
148   auto prim_name = primitive->name();
149   std::set<int64_t> ori = {};
150   auto op_infer_opt = GetPrimitiveInferImpl(primitive);
151   if (!op_infer_opt.has_value()) {
152     // some operator will be mapped to new operator on Ascend like GatherV2, however they use same Infer information
153     if (primitive->HasAttr(kAttrMeOpName)) {
154       auto ori_prim_name = GetValue<std::string>(primitive->GetAttr(kAttrMeOpName));
155       op_infer_opt = GetPrimitiveInferImpl(std::make_shared<Primitive>(ori_prim_name));
156     }
157   }
158 
159   if (op_infer_opt.has_value()) {
160     auto op_infer = op_infer_opt.value().Get();
161     if (op_infer != nullptr && ori.empty()) {
162       ori = op_infer->GetValueDependArgIndices();
163     }
164     if (prim_name == ops::kNameShapeCalc) {
165       auto only_depend_shape = GetValue<std::vector<bool>>(primitive->GetAttr(kAttrOnlyDependShape));
166       for (size_t i = 0; i < only_depend_shape.size(); i++) {
167         if (!only_depend_shape[i]) {
168           ori.insert(i);
169         }
170       }
171     }
172   } else if (ori.empty()) {
173     MS_LOG(DEBUG) << "Not find infer function GetValueDependArgIndices, prim name: " << prim_name;
174     // if not found in infer, consider all the non-tensor inputs as value depend args.
175     ori = ops::GetInputDependValueList(primitive);
176     if (prim_name == ops::kNameAvgPoolGrad && primitive->HasAttr(kAttrValueDepend)) {
177       auto value_depend_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrValueDepend));
178       ori.clear();
179       ori.insert(value_depend_vector.begin(), value_depend_vector.end());
180     }
181   }
182   if (ori.empty()) {
183     return ori;
184   }
185   size_t input_num = cnode->size() - 1;
186   std::set<int64_t> res = {};
187 
188   (void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
189                      [&](int64_t idx) { return idx < SizeToLong(input_num); });
190   if (is_proto) {
191     return res;
192   }
193   return RectifyDependListFromDynamicInputAttr(cnode, primitive, res);
194 }
195 
GetPrimitiveInferMapPtr()196 PrimitiveEvalImplMap *GetPrimitiveInferMapPtr() {
197   static PrimitiveEvalImplMap prim_eval_implement_map{
198     // core/ops infer
199     // Do not add anything in this initializer anymore since it will be removed soon, core/ops prim should register its
200     // infer in its cc file.
201   };
202   return &prim_eval_implement_map;
203 }
GetPrimitiveInferMap()204 const PrimitiveEvalImplMap &GetPrimitiveInferMap() { return *GetPrimitiveInferMapPtr(); }
205 
GetPrimitiveInferImpl(const PrimitivePtr & primitive)206 std::optional<StandardPrimitiveImplReg> GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
207   auto iter = GetPrimitiveInferMap().find(primitive);
208   if (iter != GetPrimitiveInferMap().end()) {
209     return iter->second;
210   }
211 
212   iter = GetDeprecatedPrimitiveInferMap().find(primitive);
213   if (iter != GetDeprecatedPrimitiveInferMap().end()) {
214     return iter->second;
215   }
216   return std::optional<StandardPrimitiveImplReg>();
217 }
218 
219 class OpInferCommon : public OpInferBase {
220  public:
221   OpInferCommon() = delete;
OpInferCommon(const InferAbstractImpl & infer_impl,const InferValueImpl & infer_value_impl)222   OpInferCommon(const InferAbstractImpl &infer_impl, const InferValueImpl &infer_value_impl)
223       : infer_impl_(infer_impl), infer_value_impl_(infer_value_impl) {}
224   ~OpInferCommon() = default;
225 
226   BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
227   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
228   ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
229   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
230                                     const std::vector<AbstractBasePtr> &input_args) const override;
231 
232  private:
233   InferAbstractImpl infer_impl_{nullptr};
234   InferValueImpl infer_value_impl_{nullptr};
235 };
236 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const237 BaseShapePtr OpInferCommon::InferShape(const PrimitivePtr &primitive,
238                                        const std::vector<AbstractBasePtr> &input_args) const {
239   if (!infer_impl_) {
240     return nullptr;
241   }
242 
243   auto inferred_res = infer_impl_(nullptr, primitive, input_args);
244   if (inferred_res == nullptr) {
245     return nullptr;
246   }
247 
248   return inferred_res->GetShape();
249 }
250 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const251 TypePtr OpInferCommon::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
252   if (!infer_impl_) {
253     return nullptr;
254   }
255 
256   auto inferred_res = infer_impl_(nullptr, primitive, input_args);
257   if (inferred_res == nullptr) {
258     return nullptr;
259   }
260 
261   return inferred_res->BuildType();
262 }
263 
InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const264 ValuePtr OpInferCommon::InferValue(const PrimitivePtr &primitive,
265                                    const std::vector<AbstractBasePtr> &input_args) const {
266   if (!infer_value_impl_) {
267     return nullptr;
268   }
269   return infer_value_impl_(primitive, input_args);
270 }
271 
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const272 AbstractBasePtr OpInferCommon::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
273                                                  const PrimitivePtr &primitive,
274                                                  const std::vector<AbstractBasePtr> &input_args) const {
275   if (!infer_impl_) {
276     return nullptr;
277   }
278 
279   return infer_impl_(engine, primitive, input_args);
280 }
281 
StandardPrimitiveImplReg(const InferAbstractImpl & infer_abstract,const InferValueImpl & infer_value,bool in_white_list)282 StandardPrimitiveImplReg::StandardPrimitiveImplReg(const InferAbstractImpl &infer_abstract,
283                                                    const InferValueImpl &infer_value, bool in_white_list) {
284   op_infer_ = std::make_shared<OpInferCommon>(infer_abstract, infer_value);
285   is_impl_infer_shape_and_type_ = infer_abstract != nullptr;
286   is_impl_infer_value_ = infer_value != nullptr;
287   in_white_list_ = in_white_list;
288 }
289 
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const290 AbstractBasePtr StandardPrimitiveImplReg::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
291                                                             const PrimitivePtr &primitive,
292                                                             const std::vector<AbstractBasePtr> &input_args) const {
293   if (op_infer_ == nullptr) {
294     return nullptr;
295   }
296 
297   return op_infer_->InferShapeAndType(engine, primitive, input_args);
298 }
299 
InferShape(const PrimitivePtr & prim,const AbstractBasePtrList & args) const300 BaseShapePtr StandardPrimitiveImplReg::InferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
301   if (op_infer_ == nullptr) {
302     return nullptr;
303   }
304 
305   return op_infer_->InferShape(prim, args);
306 }
307 
InferType(const PrimitivePtr & prim,const AbstractBasePtrList & args) const308 TypePtr StandardPrimitiveImplReg::InferType(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
309   if (op_infer_ == nullptr) {
310     return nullptr;
311   }
312 
313   return op_infer_->InferType(prim, args);
314 }
315 
InferValue(const PrimitivePtr & prim,const AbstractBasePtrList & args) const316 ValuePtr StandardPrimitiveImplReg::InferValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) const {
317   if (op_infer_ == nullptr) {
318     return nullptr;
319   }
320 
321   return op_infer_->InferValue(prim, args);
322 }
323 
InferShapeByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args,bool compile_phase)324 std::optional<BaseShapePtr> InferShapeByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args,
325                                                  bool compile_phase) {
326   MS_EXCEPTION_IF_NULL(primitive);
327   auto op_name = primitive->name();
328   if (compile_phase) {
329     auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
330     if (frontend_func_impl != nullptr) {
331       auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
332       if (infer_result != nullptr) {
333         return infer_result->GetShape();
334       }
335     }
336   }
337 
338   auto op_def = ops::GetOpDef(op_name);
339   if (op_def == nullptr) {
340     return std::nullopt;
341   }
342   (void)op_def->func_impl_.CheckValidation(primitive, input_args);
343   return op_def->func_impl_.InferShape(primitive, input_args);
344 }
345 
InferTypeByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args,bool compile_phase)346 std::optional<TypePtr> InferTypeByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args,
347                                            bool compile_phase) {
348   MS_EXCEPTION_IF_NULL(primitive);
349   auto op_name = primitive->name();
350   if (compile_phase) {
351     auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
352     if (frontend_func_impl != nullptr) {
353       auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
354       if (infer_result != nullptr) {
355         return infer_result->GetType();
356       }
357     }
358   }
359 
360   auto op_def = ops::GetOpDef(op_name);
361   if (op_def == nullptr) {
362     return std::nullopt;
363   }
364   (void)op_def->func_impl_.CheckValidation(primitive, input_args);
365   return op_def->func_impl_.InferType(primitive, input_args);
366 }
367 
InferAbstractByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)368 std::optional<AbstractBasePtr> InferAbstractByFuncImpl(const PrimitivePtr &primitive,
369                                                        const AbstractBasePtrList &input_args) {
370   MS_EXCEPTION_IF_NULL(primitive);
371   auto op_name = primitive->name();
372   auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
373   if (frontend_func_impl != nullptr) {
374     auto infer_result = frontend_func_impl->InferAbstract(primitive, input_args);
375     if (infer_result != nullptr) {
376       return infer_result;
377     }
378   }
379 
380   auto op_def = ops::GetOpDef(op_name);
381   if (op_def == nullptr) {
382     return std::nullopt;
383   }
384   (void)op_def->func_impl_.CheckValidation(primitive, input_args);
385   auto shape = op_def->func_impl_.InferShape(primitive, input_args);
386   auto type = op_def->func_impl_.InferType(primitive, input_args);
387   return MakeAbstract(shape, type);
388 }
389 
InferValueByFuncImpl(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)390 std::optional<ValuePtr> InferValueByFuncImpl(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) {
391   MS_EXCEPTION_IF_NULL(primitive);
392   auto op_name = primitive->name();
393   auto frontend_func_impl = ops::GetOpFrontendFuncImplPtr(op_name);
394   if (frontend_func_impl == nullptr) {
395     return std::nullopt;
396   }
397   return frontend_func_impl->InferValue(primitive, input_args);
398 }
399 
TryInferAbstract(const PrimitivePtr & primitive,const AbstractBasePtrList & input_args)400 std::optional<AbstractBasePtr> TryInferAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_args) {
401   MS_EXCEPTION_IF_NULL(primitive);
402   auto abstract_optional = abstract::InferAbstractByFuncImpl(primitive, input_args);
403   if (abstract_optional.has_value()) {
404     return abstract_optional.value();
405   }
406 
407   auto found = abstract::GetPrimitiveInferImpl(primitive);
408   if (!found.has_value() || !found.value().IsImplInferShapeAndType()) {
409     MS_LOG(DEBUG) << "The infer function of [" << primitive->name() << "] is not defined.";
410     return std::nullopt;
411   }
412   return found.value().InferShapeAndType(nullptr, primitive, input_args);
413 }
414 }  // namespace abstract
415 }  // namespace mindspore
416