1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ 19 #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ 20 21 #include <vector> 22 #include <set> 23 #include <string> 24 #include <memory> 25 #include <optional> 26 #include "utils/hash_map.h" 27 #include "ir/primitive.h" 28 #include "ops/primitive_c.h" 29 #include "abstract/abstract_value.h" 30 #include "ir/anf.h" 31 #include "abstract/ops/op_infer.h" 32 33 namespace mindspore { 34 namespace abstract { 35 using InferAbstractImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, 36 const AbstractBasePtrList &); 37 using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &); 38 39 class MS_CORE_API StandardPrimitiveImplReg { 40 public: 41 StandardPrimitiveImplReg() = default; 42 StandardPrimitiveImplReg(const InferAbstractImpl &infer_abstract, const InferValueImpl &infer_value, 43 bool in_white_list); StandardPrimitiveImplReg(const OpInferBasePtr & op_infer,bool is_impl_infer_value)44 StandardPrimitiveImplReg(const OpInferBasePtr &op_infer, bool is_impl_infer_value) 45 : op_infer_(op_infer), is_impl_infer_value_(is_impl_infer_value) {} 46 ~StandardPrimitiveImplReg() = default; 47 Get()48 const OpInferBasePtr Get() const { return op_infer_; } 49 50 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive, 51 const std::vector<AbstractBasePtr> &input_args) const; 52 BaseShapePtr InferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args) const; 53 TypePtr InferType(const PrimitivePtr &prim, const AbstractBasePtrList &args) const; 54 ValuePtr InferValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) const; 55 IsImplInferShapeAndType()56 bool IsImplInferShapeAndType() const { return is_impl_infer_shape_and_type_ && op_infer_ != nullptr; } IsImplInferValue()57 bool IsImplInferValue() const { return is_impl_infer_value_ && op_infer_ != nullptr; } IsInWhiteList()58 bool IsInWhiteList() const { return in_white_list_; } 59 60 private: 61 OpInferBasePtr op_infer_{nullptr}; // Infer shape, type and value. 62 bool is_impl_infer_shape_and_type_{true}; 63 bool is_impl_infer_value_{false}; 64 // in_white_list_ is true means this primitive can be executed by vm backend 65 // else will be optimized by frontend 66 bool in_white_list_{true}; 67 }; 68 69 void IsImplInferShapeAndType(const OpInferBasePtr &op_infer); 70 void IsImplInferValue(const OpInferBasePtr &op_infer); 71 72 using PrimitiveEvalImplMap = 73 mindspore::HashMap<PrimitivePtr, StandardPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>; 74 75 using PrimShapeDependMap = mindspore::HashMap<std::string, std::set<int64_t>>; 76 77 MS_CORE_API const PrimitiveEvalImplMap &GetPrimitiveInferMap(); 78 MS_CORE_API PrimitiveEvalImplMap *GetPrimitiveInferMapPtr(); 79 80 MS_CORE_API const PrimitiveEvalImplMap &GetDeprecatedPrimitiveInferMap(); 81 MS_CORE_API PrimitiveEvalImplMap *GetDeprecatedPrimitiveInferMapPtr(); 82 83 // get prim infer from infer map or deprecated infer map 84 MS_CORE_API std::optional<StandardPrimitiveImplReg> GetPrimitiveInferImpl(const PrimitivePtr &primitive); 85 86 MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode, bool is_proto = false); 87 88 class RegisterStandardPrimitiveEvalHelper { 89 public: 90 RegisterStandardPrimitiveEvalHelper(PrimitiveEvalImplMap *eval_map, const PrimitivePtr &primitive, 91 const InferAbstractImpl &infer_shape_and_type_impl, 92 const InferValueImpl &infer_value_impl, const bool is_white_list = true) { 93 const StandardPrimitiveImplReg impl_reg{infer_shape_and_type_impl, infer_value_impl, is_white_list}; 94 eval_map->emplace(primitive, impl_reg); 95 } 96 97 RegisterStandardPrimitiveEvalHelper(PrimitiveEvalImplMap *eval_map, const PrimitivePtr &primitive, 98 const OpInferBasePtr &op_infer, bool is_impl_infer_value = false) { 99 const StandardPrimitiveImplReg impl_reg{op_infer, is_impl_infer_value}; 100 eval_map->emplace(primitive, impl_reg); 101 } 102 ~RegisterStandardPrimitiveEvalHelper() = default; 103 }; 104 105 #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_shape_and_type_impl, infer_value_impl, is_white_list) \ 106 static auto helper_eval_##name = abstract::RegisterStandardPrimitiveEvalHelper( \ 107 abstract::GetPrimitiveInferMapPtr(), primitive, infer_shape_and_type_impl, infer_value_impl, is_white_list); \ 108 std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \ 109 name out; \ 110 return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl()); \ 111 } \ 112 ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name); 113 114 #define REGISTER_PRIMITIVE_OP_INFER_IMPL(name, primitive, OP_INFER_ClASS, is_impl_infer_value) \ 115 const auto helper_op_infer_##name = abstract::RegisterStandardPrimitiveEvalHelper( \ 116 abstract::GetPrimitiveInferMapPtr(), primitive, std::make_shared<OP_INFER_ClASS>(), is_impl_infer_value); \ 117 std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \ 118 name out; \ 119 return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl()); \ 120 } \ 121 ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name) 122 123 MS_CORE_API std::optional<BaseShapePtr> InferShapeByFuncImpl(const PrimitivePtr &primitive, 124 const AbstractBasePtrList &input_args, 125 bool compile_phase = false); 126 MS_CORE_API std::optional<TypePtr> InferTypeByFuncImpl(const PrimitivePtr &primitive, 127 const AbstractBasePtrList &input_args, 128 bool compile_phase = false); 129 MS_CORE_API std::optional<AbstractBasePtr> InferAbstractByFuncImpl(const PrimitivePtr &primitive, 130 const AbstractBasePtrList &input_args); 131 MS_CORE_API std::optional<ValuePtr> InferValueByFuncImpl(const PrimitivePtr &primitive, 132 const AbstractBasePtrList &input_args); 133 134 MS_CORE_API std::optional<AbstractBasePtr> TryInferAbstract(const PrimitivePtr &primitive, 135 const AbstractBasePtrList &input_args); 136 } // namespace abstract 137 } // namespace mindspore 138 #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ 139