1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H 18 #define MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H 19 20 #include <vector> 21 #include <unordered_map> 22 #include <memory> 23 #include <string> 24 #include "ir/cell.h" 25 #include "ir/primitive.h" 26 #include "abstract/abstract_value.h" 27 #include "ir/anf.h" 28 #include "mindapi/base/macros.h" 29 30 namespace mindspore::ops { 31 class OpFrontendFuncImpl { 32 public: 33 OpFrontendFuncImpl() = default; 34 virtual ~OpFrontendFuncImpl() = default; 35 36 /// \brief Infer the output value for target operator. Only override when needed. 37 /// 38 /// \param[in] primitive Operator's primitive. 39 /// \param[in] input_args Operator's inputs. 40 /// 41 /// \return Inferred Value based on given inputs. InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)42 virtual ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const { 43 return nullptr; 44 } 45 46 /// \brief Infer the related Abstract for target operator. 47 /// 48 /// \param[in] primitive Operator's primitive. 49 /// \param[in] input_args Operator's inputs. 50 /// 51 /// \return AbstractBasePtr with inferred shape and inferred type. InferAbstract(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)52 virtual AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, 53 const std::vector<AbstractBasePtr> &input_args) const { 54 return nullptr; 55 } 56 }; 57 58 using OpFrontendFuncImplPtr = std::shared_ptr<OpFrontendFuncImpl>; 59 60 class FrontendFuncImplHolder { 61 public: FrontendFuncImplHolder(const OpFrontendFuncImplPtr & func_impl)62 explicit FrontendFuncImplHolder(const OpFrontendFuncImplPtr &func_impl) : func_impl_(func_impl) {} 63 ~FrontendFuncImplHolder() = default; get_func_impl()64 OpFrontendFuncImplPtr get_func_impl() { return func_impl_; } 65 66 private: 67 OpFrontendFuncImplPtr func_impl_{nullptr}; 68 }; 69 70 using OpsFrontendFuncImplMap = std::unordered_map<std::string, FrontendFuncImplHolder>; 71 72 MS_CORE_API OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name); 73 74 class MS_CORE_API RegFrontendFuncImplHelper { 75 public: 76 RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl); 77 ~RegFrontendFuncImplHelper() = default; 78 }; 79 80 #define REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(name, func_impl_class) \ 81 static auto helper_##func_impl_class = RegFrontendFuncImplHelper(name, std::make_shared<func_impl_class>()); 82 83 using InferValueFunc = std::function<ValuePtr(const std::string &, const AbstractBasePtrList &)>; 84 class MS_CORE_API InferValueCallback { 85 public: 86 InferValueCallback(const InferValueCallback &) = delete; 87 InferValueCallback &operator=(const InferValueCallback &) = delete; 88 89 static InferValueCallback &GetInstance(); 90 91 void RegImpl(const std::string &impl_type, const InferValueFunc &py_func); 92 ValuePtr CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); 93 ValuePtr CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args); 94 95 private: 96 InferValueCallback() = default; ~InferValueCallback()97 ~InferValueCallback() {} 98 99 private: 100 InferValueFunc python_impl_{nullptr}; 101 InferValueFunc kernel_impl_{nullptr}; 102 }; 103 104 class MS_CORE_API InferValueImplRegister { 105 public: 106 InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn); 107 ~InferValueImplRegister() = default; 108 }; 109 110 #define INFER_VALUE_IMPL_REGISTER(impl_type, func) \ 111 static auto reg_##impl_type##_##func = mindspore::ops::InferValueImplRegister(#impl_type, func) 112 } // namespace mindspore::ops 113 #endif // MINDSPORE_CORE_OPS_FRONTEND_FUNC_IMPL_H 114