• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H
18 #define MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H
19 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ir/cell.h"
#include "ir/primitive.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
#include "mindapi/base/macros.h"
29 
30 namespace mindspore::ops {
31 class OpFrontendFuncImpl {
32  public:
33   OpFrontendFuncImpl() = default;
34   virtual ~OpFrontendFuncImpl() = default;
35 
36   /// \brief Infer the output value for target operator. Only override when needed.
37   ///
38   /// \param[in] primitive Operator's primitive.
39   /// \param[in] input_args Operator's inputs.
40   ///
41   /// \return Inferred Value based on given inputs.
InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)42   virtual ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
43     return nullptr;
44   }
45 
46   /// \brief Infer the related Abstract for target operator.
47   ///
48   /// \param[in] primitive Operator's primitive.
49   /// \param[in] input_args Operator's inputs.
50   ///
51   /// \return AbstractBasePtr with inferred shape and inferred type.
InferAbstract(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)52   virtual AbstractBasePtr InferAbstract(const PrimitivePtr &primitive,
53                                         const std::vector<AbstractBasePtr> &input_args) const {
54     return nullptr;
55   }
56 };
57 
58 using OpFrontendFuncImplPtr = std::shared_ptr<OpFrontendFuncImpl>;
59 
60 class FrontendFuncImplHolder {
61  public:
FrontendFuncImplHolder(const OpFrontendFuncImplPtr & func_impl)62   explicit FrontendFuncImplHolder(const OpFrontendFuncImplPtr &func_impl) : func_impl_(func_impl) {}
63   ~FrontendFuncImplHolder() = default;
get_func_impl()64   OpFrontendFuncImplPtr get_func_impl() { return func_impl_; }
65 
66  private:
67   OpFrontendFuncImplPtr func_impl_{nullptr};
68 };
69 
70 using OpsFrontendFuncImplMap = std::unordered_map<std::string, FrontendFuncImplHolder>;
71 
72 MS_CORE_API OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name);
73 
74 class MS_CORE_API RegFrontendFuncImplHelper {
75  public:
76   RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl);
77   ~RegFrontendFuncImplHelper() = default;
78 };
79 
/// \brief Register a default-constructed `func_impl_class` (held by shared_ptr) under `name`.
///
/// Expands to a `static` variable `helper_<func_impl_class>` whose initializer runs the
/// RegFrontendFuncImplHelper constructor when the including translation unit is initialized.
/// Names are fully qualified, as in INFER_VALUE_IMPL_REGISTER, so the macro also works outside
/// `namespace mindspore::ops`. `func_impl_class` must be an unqualified class name because it is
/// token-pasted, and a given class can be registered only once per translation unit. The
/// expansion ends with `;` (kept for compatibility with existing call sites).
#define REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(name, func_impl_class) \
  static auto helper_##func_impl_class =                                    \
    mindspore::ops::RegFrontendFuncImplHelper(name, std::make_shared<func_impl_class>());
82 
83 using InferValueFunc = std::function<ValuePtr(const std::string &, const AbstractBasePtrList &)>;
84 class MS_CORE_API InferValueCallback {
85  public:
86   InferValueCallback(const InferValueCallback &) = delete;
87   InferValueCallback &operator=(const InferValueCallback &) = delete;
88 
89   static InferValueCallback &GetInstance();
90 
91   void RegImpl(const std::string &impl_type, const InferValueFunc &py_func);
92   ValuePtr CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args);
93   ValuePtr CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args);
94 
95  private:
96   InferValueCallback() = default;
~InferValueCallback()97   ~InferValueCallback() {}
98 
99  private:
100   InferValueFunc python_impl_{nullptr};
101   InferValueFunc kernel_impl_{nullptr};
102 };
103 
104 class MS_CORE_API InferValueImplRegister {
105  public:
106   InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn);
107   ~InferValueImplRegister() = default;
108 };
109 
/// \brief Register `func` as the infer-value callback of kind `impl_type`.
///
/// `impl_type` is a bare token. It is stringified (#impl_type) to form the string argument passed
/// to InferValueImplRegister, and, together with `func`, pasted into the name of the generated
/// `static` variable `reg_<impl_type>_<func>`. Both arguments must therefore be plain identifiers.
/// The expansion has no trailing semicolon; call sites supply it.
#define INFER_VALUE_IMPL_REGISTER(impl_type, func) \
  static auto reg_##impl_type##_##func =           \
    mindspore::ops::InferValueImplRegister(#impl_type, func)
112 }  //  namespace mindspore::ops
113 #endif  //  MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H
114