1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_ 18 #define MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_ 19 20 #include <algorithm> 21 #include <functional> 22 #include <map> 23 #include <memory> 24 #include <string> 25 #include <utility> 26 #include <vector> 27 #include "kernel/kernel_factory.h" 28 #include "utils/log_adapter.h" 29 30 namespace mindspore { 31 namespace kernel { 32 template <class C> 33 class Factory : public FactoryBase { 34 using CreatorFunc = std::function<std::shared_ptr<C>()>; 35 36 public: 37 Factory(const Factory &) = delete; 38 void operator=(const Factory &) = delete; 39 Instance()40 static Factory<C> &Instance() { 41 std::string key = typeid(C).name(); 42 FactoryBase *instance = FactoryBase::GetInstance(key); 43 if (instance == nullptr) { 44 FactoryBase::CreateFactory(key, std::make_unique<Factory<C>>()); 45 instance = FactoryBase::GetInstance(key); 46 } 47 MS_EXCEPTION_IF_NULL(instance); 48 return *static_cast<Factory<C> *>(instance); 49 } 50 Register(const std::string & name,CreatorFunc && creator)51 void Register(const std::string &name, CreatorFunc &&creator) { 52 if (IsRegistered(name)) { 53 MS_LOG(EXCEPTION) << "Kernel " << name << " is already registered!"; 54 } 55 (void)kernel_mod_creators_.emplace(name, creator); 56 } 57 UnRegister(const std::string & name)58 void UnRegister(const std::string &name) { 59 auto iter = kernel_mod_creators_.find(name); 60 if (iter != kernel_mod_creators_.end()) { 61 kernel_mod_creators_.erase(iter); 62 } 63 } 64 Create(const std::string & name)65 std::shared_ptr<C> Create(const std::string &name) const { 66 typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name); 67 if (iter != kernel_mod_creators_.cend()) { 68 return (iter->second)(); 69 } 70 return nullptr; 71 } 72 IsRegistered(const std::string & name)73 bool IsRegistered(const std::string &name) const { 74 if (kernel_mod_creators_.find(name) != kernel_mod_creators_.end()) { 75 return true; 76 } 77 return false; 78 } 79 80 Factory() = default; 81 ~Factory() = default; 82 83 private: 84 std::map<std::string, CreatorFunc> kernel_mod_creators_; 85 }; 86 87 template <class C> 88 class KernelRegistrar { 89 public: KernelRegistrar(const std::string & name,std::function<std::shared_ptr<C> ()> creator)90 explicit KernelRegistrar(const std::string &name, std::function<std::shared_ptr<C>()> creator) noexcept { 91 Factory<C>::Instance().Register(name, std::move(creator)); 92 } 93 ~KernelRegistrar() = default; 94 }; 95 96 // Helper macro for factory registration. 97 #define MS_KERNEL_FACTORY_REG(BASE_CLASS, NAME, DERIVE_CLASS) \ 98 static_assert(std::is_base_of<BASE_CLASS, DERIVE_CLASS>::value, #DERIVE_CLASS " must be derived from " #BASE_CLASS); \ 99 static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg( \ 100 #NAME, []() { return std::make_shared<DERIVE_CLASS>(); }) 101 102 #define MS_KERNEL_FACTORY_REG_BY_CREATOR(BASE_CLASS, NAME, CREATOR) \ 103 static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg(#NAME, CREATOR) 104 105 #define MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(BASE_CLASS, NAME, DERIVE_CLASS) \ 106 static_assert(std::is_base_of<BASE_CLASS, DERIVE_CLASS>::value, #DERIVE_CLASS " must be derived from " #BASE_CLASS); \ 107 static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg( \ 108 #NAME, []() { return std::make_shared<DERIVE_CLASS>(#NAME); }) 109 } // namespace kernel 110 } // namespace mindspore 111 #endif // MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_ 112