• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_
18 #define MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_
19 
20 #include <algorithm>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "kernel/kernel_factory.h"
28 #include "utils/log_adapter.h"
29 
30 namespace mindspore {
31 namespace kernel {
32 template <class C>
33 class Factory : public FactoryBase {
34   using CreatorFunc = std::function<std::shared_ptr<C>()>;
35 
36  public:
37   Factory(const Factory &) = delete;
38   void operator=(const Factory &) = delete;
39 
Instance()40   static Factory<C> &Instance() {
41     std::string key = typeid(C).name();
42     FactoryBase *instance = FactoryBase::GetInstance(key);
43     if (instance == nullptr) {
44       FactoryBase::CreateFactory(key, std::make_unique<Factory<C>>());
45       instance = FactoryBase::GetInstance(key);
46     }
47     MS_EXCEPTION_IF_NULL(instance);
48     return *static_cast<Factory<C> *>(instance);
49   }
50 
Register(const std::string & name,CreatorFunc && creator)51   void Register(const std::string &name, CreatorFunc &&creator) {
52     if (IsRegistered(name)) {
53       MS_LOG(EXCEPTION) << "Kernel " << name << " is already registered!";
54     }
55     (void)kernel_mod_creators_.emplace(name, creator);
56   }
57 
UnRegister(const std::string & name)58   void UnRegister(const std::string &name) {
59     auto iter = kernel_mod_creators_.find(name);
60     if (iter != kernel_mod_creators_.end()) {
61       kernel_mod_creators_.erase(iter);
62     }
63   }
64 
Create(const std::string & name)65   std::shared_ptr<C> Create(const std::string &name) const {
66     typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name);
67     if (iter != kernel_mod_creators_.cend()) {
68       return (iter->second)();
69     }
70     return nullptr;
71   }
72 
IsRegistered(const std::string & name)73   bool IsRegistered(const std::string &name) const {
74     if (kernel_mod_creators_.find(name) != kernel_mod_creators_.end()) {
75       return true;
76     }
77     return false;
78   }
79 
80   Factory() = default;
81   ~Factory() = default;
82 
83  private:
84   std::map<std::string, CreatorFunc> kernel_mod_creators_;
85 };
86 
87 template <class C>
88 class KernelRegistrar {
89  public:
KernelRegistrar(const std::string & name,std::function<std::shared_ptr<C> ()> creator)90   explicit KernelRegistrar(const std::string &name, std::function<std::shared_ptr<C>()> creator) noexcept {
91     Factory<C>::Instance().Register(name, std::move(creator));
92   }
93   ~KernelRegistrar() = default;
94 };
95 
// Helper macro for factory registration.
// Expands to:
//   1. a static_assert that DERIVE_CLASS is derived from BASE_CLASS, and
//   2. a static const KernelRegistrar<BASE_CLASS> named g_<NAME>_<BASE_CLASS>_reg, constructed with
//      the stringized NAME and a lambda that returns std::make_shared<DERIVE_CLASS>().
// The expansion does not end with a semicolon; the caller supplies it.
#define MS_KERNEL_FACTORY_REG(BASE_CLASS, NAME, DERIVE_CLASS)        \
  static_assert(std::is_base_of<BASE_CLASS, DERIVE_CLASS>::value,    \
                #DERIVE_CLASS " must be derived from " #BASE_CLASS); \
  static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg(#NAME, []() { return std::make_shared<DERIVE_CLASS>(); })
101 
// Registration variant that takes a ready-made creator instead of a class name.
// Expands to a static const KernelRegistrar<BASE_CLASS> named g_<NAME>_<BASE_CLASS>_reg, constructed
// with the stringized NAME and CREATOR. Unlike MS_KERNEL_FACTORY_REG, no static_assert is emitted.
// The expansion does not end with a semicolon; the caller supplies it.
#define MS_KERNEL_FACTORY_REG_BY_CREATOR(BASE_CLASS, NAME, CREATOR) static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg(#NAME, CREATOR)
104 
// Registration variant for classes whose constructor takes the registration name.
// Expands to:
//   1. a static_assert that DERIVE_CLASS is derived from BASE_CLASS, and
//   2. a static const KernelRegistrar<BASE_CLASS> named g_<NAME>_<BASE_CLASS>_reg, constructed with
//      the stringized NAME and a lambda that returns std::make_shared<DERIVE_CLASS>(#NAME), i.e. the
//      stringized NAME is also passed to the DERIVE_CLASS constructor.
// The expansion does not end with a semicolon; the caller supplies it.
#define MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(BASE_CLASS, NAME, DERIVE_CLASS) \
  static_assert(std::is_base_of<BASE_CLASS, DERIVE_CLASS>::value,             \
                #DERIVE_CLASS " must be derived from " #BASE_CLASS);          \
  static const KernelRegistrar<BASE_CLASS> g_##NAME##_##BASE_CLASS##_reg(#NAME, []() { return std::make_shared<DERIVE_CLASS>(#NAME); })
109 }  // namespace kernel
110 }  // namespace mindspore
111 #endif  // MINDSPORE_CCSRC_PLUGIN_FACTORY_MS_FACTORY_H_
112