1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_PRIMITIVE_C_H_ 18 #define MINDSPORE_CORE_OPS_PRIMITIVE_C_H_ 19 #include <string> 20 #include <vector> 21 #include <map> 22 #include <memory> 23 #include "ir/primitive.h" 24 #include "abstract/primitive_infer_map.h" 25 #include "ir/value.h" 26 namespace mindspore { 27 namespace ops { 28 /// \brief PrimitiveC defines the base class for end side operators. 29 class MS_CORE_API PrimitiveC : public Primitive { 30 public: 31 /// \brief Constructor for PrimitiveC. 32 /// 33 /// \param[in] name The name of the end side operator. PrimitiveC(const std::string & name)34 explicit PrimitiveC(const std::string &name) : Primitive(name) {} 35 MS_DECLARE_PARENT(PrimitiveC, Primitive); 36 37 /// \brief Destructor of PrimitiveC. 38 ~PrimitiveC() = default; 39 40 /// \brief Derive the abstract of the PrimitiveC object. 41 /// 42 /// \param[in] abstract_list The abstract of the inputs of the PrimitiveC object. 43 /// \return The abstract of the PrimitiveC object. 44 AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list); 45 46 protected: 47 void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name); 48 }; 49 50 using OpPrimCDefineFunc = std::function<std::shared_ptr<PrimitiveC>()>; 51 /// \brief OpPrimCRegister defines the singleton to save the end side operators. 52 class MS_CORE_API OpPrimCRegister { 53 public: 54 /// \brief Destructor of OpPrimCRegister. ~OpPrimCRegister()55 ~OpPrimCRegister() {} 56 57 /// \brief Get the OpPrimCRegister singleton. 58 /// 59 /// \return The OpPrimCRegister singleton. 60 static OpPrimCRegister &GetInstance(); 61 62 /// \brief Get PrimCMap of the OpPrimCRegister singleton. 63 /// 64 /// \return The PrimCMap of the OpPrimCRegister singleton. 65 std::map<std::string, OpPrimCDefineFunc> GetPrimCMap(); 66 67 /// \brief Add an element into the PrimCMap of the OpPrimCRegister singleton. 68 /// 69 /// param[in] kname The name of the input end side operator. 70 /// param[in] fn The input end side operator. 71 void SetPrimCMap(const std::string &kname, const OpPrimCDefineFunc &fn); 72 73 private: OpPrimCRegister()74 OpPrimCRegister() {} 75 std::map<std::string, OpPrimCDefineFunc> op_primc_fns_; 76 }; 77 78 /// \brief OpPrimCRegisterHelper defines the helper class for the OpPrimCRegister singleton. 79 class MS_CORE_API OpPrimCRegisterHelper { 80 public: 81 /// \brief Constructor for OpPrimCRegisterHelper. 82 /// 83 /// param[in] kname The name of the input end side operator. 84 /// param[in] fn The input end side operator. OpPrimCRegisterHelper(const std::string & kname,const OpPrimCDefineFunc & fn)85 OpPrimCRegisterHelper(const std::string &kname, const OpPrimCDefineFunc &fn) { 86 OpPrimCRegister::GetInstance().SetPrimCMap(kname, fn); 87 } 88 89 /// Destructor of OpPrimCRegisterHelper. 90 ~OpPrimCRegisterHelper() = default; 91 92 private: 93 int id_{0}; 94 }; 95 96 #define REGISTER_PRIMITIVE_C(kname, primc) \ 97 std::shared_ptr<PrimitiveC> GetDefaultPrimC##primc() { \ 98 auto out = std::make_shared<primc>(); \ 99 return out; \ 100 } \ 101 OpPrimCRegisterHelper primc_gen_##kname(kname, GetDefaultPrimC##primc); 102 } // namespace ops 103 } // namespace mindspore 104 #endif // MINDSPORE_CORE_OPS_PRIMITIVE_C_H_ 105