/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
19 #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
20 
21 #include <vector>
22 #include <set>
23 #include <string>
24 #include <memory>
25 #include <optional>
26 #include "utils/hash_map.h"
27 #include "ir/primitive.h"
28 #include "ops/primitive_c.h"
29 #include "abstract/abstract_value.h"
30 #include "ir/anf.h"
31 #include "abstract/ops/op_infer.h"
32 
33 namespace mindspore {
34 namespace abstract {
35 using InferAbstractImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
36                                               const AbstractBasePtrList &);
37 using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &);
38 
39 class MS_CORE_API StandardPrimitiveImplReg {
40  public:
41   StandardPrimitiveImplReg() = default;
42   StandardPrimitiveImplReg(const InferAbstractImpl &infer_abstract, const InferValueImpl &infer_value,
43                            bool in_white_list);
StandardPrimitiveImplReg(const OpInferBasePtr & op_infer,bool is_impl_infer_value)44   StandardPrimitiveImplReg(const OpInferBasePtr &op_infer, bool is_impl_infer_value)
45       : op_infer_(op_infer), is_impl_infer_value_(is_impl_infer_value) {}
46   ~StandardPrimitiveImplReg() = default;
47 
Get()48   const OpInferBasePtr Get() const { return op_infer_; }
49 
50   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
51                                     const std::vector<AbstractBasePtr> &input_args) const;
52   BaseShapePtr InferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args) const;
53   TypePtr InferType(const PrimitivePtr &prim, const AbstractBasePtrList &args) const;
54   ValuePtr InferValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) const;
55 
IsImplInferShapeAndType()56   bool IsImplInferShapeAndType() const { return is_impl_infer_shape_and_type_ && op_infer_ != nullptr; }
IsImplInferValue()57   bool IsImplInferValue() const { return is_impl_infer_value_ && op_infer_ != nullptr; }
IsInWhiteList()58   bool IsInWhiteList() const { return in_white_list_; }
59 
60  private:
61   OpInferBasePtr op_infer_{nullptr};  // Infer shape, type and value.
62   bool is_impl_infer_shape_and_type_{true};
63   bool is_impl_infer_value_{false};
64   // in_white_list_ is true means this primitive can be executed by vm backend
65   // else will be optimized by frontend
66   bool in_white_list_{true};
67 };
68 
69 void IsImplInferShapeAndType(const OpInferBasePtr &op_infer);
70 void IsImplInferValue(const OpInferBasePtr &op_infer);
71 
72 using PrimitiveEvalImplMap =
73   mindspore::HashMap<PrimitivePtr, StandardPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
74 
75 using PrimShapeDependMap = mindspore::HashMap<std::string, std::set<int64_t>>;
76 
77 MS_CORE_API const PrimitiveEvalImplMap &GetPrimitiveInferMap();
78 MS_CORE_API PrimitiveEvalImplMap *GetPrimitiveInferMapPtr();
79 
80 MS_CORE_API const PrimitiveEvalImplMap &GetDeprecatedPrimitiveInferMap();
81 MS_CORE_API PrimitiveEvalImplMap *GetDeprecatedPrimitiveInferMapPtr();
82 
83 // get prim infer from infer map or deprecated infer map
84 MS_CORE_API std::optional<StandardPrimitiveImplReg> GetPrimitiveInferImpl(const PrimitivePtr &primitive);
85 
86 MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode, bool is_proto = false);
87 
88 class RegisterStandardPrimitiveEvalHelper {
89  public:
90   RegisterStandardPrimitiveEvalHelper(PrimitiveEvalImplMap *eval_map, const PrimitivePtr &primitive,
91                                       const InferAbstractImpl &infer_shape_and_type_impl,
92                                       const InferValueImpl &infer_value_impl, const bool is_white_list = true) {
93     const StandardPrimitiveImplReg impl_reg{infer_shape_and_type_impl, infer_value_impl, is_white_list};
94     eval_map->emplace(primitive, impl_reg);
95   }
96 
97   RegisterStandardPrimitiveEvalHelper(PrimitiveEvalImplMap *eval_map, const PrimitivePtr &primitive,
98                                       const OpInferBasePtr &op_infer, bool is_impl_infer_value = false) {
99     const StandardPrimitiveImplReg impl_reg{op_infer, is_impl_infer_value};
100     eval_map->emplace(primitive, impl_reg);
101   }
102   ~RegisterStandardPrimitiveEvalHelper() = default;
103 };
104 
// Registers function-pointer inference implementations for `primitive` and a default-instance factory
// for the operator class `name`.
//
// The expansion defines three entities:
//   - helper_eval_<name>: a static RegisterStandardPrimitiveEvalHelper whose construction inserts the
//     record into the map returned by abstract::GetPrimitiveInferMapPtr();
//   - GetDefaultPrimC<name>(): default-constructs a `name` object and returns its impl() cast to
//     std::shared_ptr<ops::PrimitiveC>, so `name` must be default-constructible and provide impl();
//   - primc_gen_<name>: an ops::OpPrimCRegisterHelper built from the stringized name and that function.
//
// The expansion already ends with ';'. It must be used in a scope where `abstract::` and `ops::` resolve.
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_shape_and_type_impl, infer_value_impl, is_white_list) \
  static auto helper_eval_##name = abstract::RegisterStandardPrimitiveEvalHelper(                                 \
    abstract::GetPrimitiveInferMapPtr(), primitive, infer_shape_and_type_impl, infer_value_impl, is_white_list);  \
  std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() {                                                      \
    name out;                                                                                                     \
    return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl());                                                \
  }                                                                                                               \
  ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
113 
// Registers an inference object of type OP_INFER_CLASS for `primitive` and a default-instance factory
// for the operator class `name`.
//
// The expansion defines three entities:
//   - helper_op_infer_<name>: a const RegisterStandardPrimitiveEvalHelper whose construction inserts a
//     record holding std::make_shared<OP_INFER_CLASS>() into the map returned by
//     abstract::GetPrimitiveInferMapPtr(), so OP_INFER_CLASS must be default-constructible;
//   - GetDefaultPrimC<name>(): default-constructs a `name` object and returns its impl() cast to
//     std::shared_ptr<ops::PrimitiveC>, so `name` must be default-constructible and provide impl();
//   - primc_gen_<name>: an ops::OpPrimCRegisterHelper built from the stringized name and that function.
//
// The expansion does not end with ';', so the call site must supply one. It must be used in a scope
// where `abstract::` and `ops::` resolve.
#define REGISTER_PRIMITIVE_OP_INFER_IMPL(name, primitive, OP_INFER_CLASS, is_impl_infer_value)                \
  const auto helper_op_infer_##name = abstract::RegisterStandardPrimitiveEvalHelper(                          \
    abstract::GetPrimitiveInferMapPtr(), primitive, std::make_shared<OP_INFER_CLASS>(), is_impl_infer_value); \
  std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() {                                                  \
    name out;                                                                                                 \
    return std::dynamic_pointer_cast<ops::PrimitiveC>(out.impl());                                            \
  }                                                                                                           \
  ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name)
122 
123 MS_CORE_API std::optional<BaseShapePtr> InferShapeByFuncImpl(const PrimitivePtr &primitive,
124                                                              const AbstractBasePtrList &input_args,
125                                                              bool compile_phase = false);
126 MS_CORE_API std::optional<TypePtr> InferTypeByFuncImpl(const PrimitivePtr &primitive,
127                                                        const AbstractBasePtrList &input_args,
128                                                        bool compile_phase = false);
129 MS_CORE_API std::optional<AbstractBasePtr> InferAbstractByFuncImpl(const PrimitivePtr &primitive,
130                                                                    const AbstractBasePtrList &input_args);
131 MS_CORE_API std::optional<ValuePtr> InferValueByFuncImpl(const PrimitivePtr &primitive,
132                                                          const AbstractBasePtrList &input_args);
133 
134 MS_CORE_API std::optional<AbstractBasePtr> TryInferAbstract(const PrimitivePtr &primitive,
135                                                             const AbstractBasePtrList &input_args);
136 }  // namespace abstract
137 }  // namespace mindspore
138 #endif  // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
139