1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 18 19 #include <functional> 20 #include <string> 21 #include <memory> 22 23 #include "utils/hash_map.h" 24 #include "backend/common/graph_kernel/expanders/utils.h" 25 #include "include/backend/visible.h" 26 27 namespace mindspore::graphkernel::expanders { 28 class BACKEND_EXPORT OpDescFactory { 29 public: Instance()30 static OpDescFactory &Instance() { 31 static OpDescFactory instance{}; 32 return instance; 33 } HasOp(const std::string & op)34 bool HasOp(const std::string &op) const { return creators.find(op) != creators.end(); } GetOp(const std::string & op)35 std::shared_ptr<OpDesc> GetOp(const std::string &op) const { 36 if (auto iter = creators.find(op); iter != creators.end()) { 37 auto op_desc = iter->second(); 38 op_desc->name_ = op; 39 return op_desc; 40 } 41 return nullptr; 42 } 43 OpDescFactory() = default; 44 ~OpDescFactory() = default; 45 46 using RegFunc = std::function<std::shared_ptr<OpDesc>()>; Register(const std::string & op,const RegFunc & func)47 void Register(const std::string &op, const RegFunc &func) { creators[op] = func; } 48 49 private: 50 inline static mindspore::HashMap<std::string, RegFunc> creators; 51 }; 52 53 class OpDescRegister { 54 public: OpDescRegister(const std::string & name,const OpDescFactory::RegFunc & func)55 OpDescRegister(const std::string &name, const OpDescFactory::RegFunc &func) : func_(func) { 56 OpDescFactory::Instance().Register(name, func); 57 } 58 ~OpDescRegister() = default; 59 60 private: 61 // for pclint-plus 62 OpDescFactory::RegFunc func_; 63 }; 64 65 #define JOIN(x, y) x##y 66 #define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt) 67 #define EXPANDER_OP_DESC_REGISTER(name, cls) \ 68 const OpDescRegister UNIQUE_NAME(g_expander_opdesc_, __COUNTER__)( \ 69 name, []() noexcept -> std::shared_ptr<OpDesc> { return std::make_shared<cls>(); }) 70 } // namespace mindspore::graphkernel::expanders 71 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 72