1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 18 19 #include <unordered_map> 20 #include <functional> 21 #include <string> 22 #include <memory> 23 24 #include "backend/optimizer/graph_kernel/expanders/utils.h" 25 26 namespace mindspore { 27 namespace opt { 28 namespace expanders { 29 class OpExpanderFactory { 30 public: Instance()31 static OpExpanderFactory &Instance() { 32 static OpExpanderFactory instance; 33 return instance; 34 } GetExpander(const std::string & op)35 std::shared_ptr<OpExpander> GetExpander(const std::string &op) { 36 if (auto iter = creators.find(op); iter != creators.end()) { 37 auto expander_ptr = iter->second(); 38 expander_ptr->op_ = op; 39 return expander_ptr; 40 } 41 return nullptr; 42 } 43 ~OpExpanderFactory() = default; 44 45 using RegFunc = std::function<std::shared_ptr<OpExpander>()>; Register(const std::string & op,const RegFunc & func)46 void Register(const std::string &op, const RegFunc &func) { creators[op] = func; } 47 48 private: 49 std::unordered_map<std::string, RegFunc> creators; 50 }; 51 52 class OpExpanderRegister { 53 public: OpExpanderRegister(const std::string & name,const OpExpanderFactory::RegFunc & func)54 OpExpanderRegister(const std::string &name, const OpExpanderFactory::RegFunc &func) : func_(func) { 55 OpExpanderFactory::Instance().Register(name, func); 56 } 57 ~OpExpanderRegister() = default; 58 59 private: 60 // for pclint-plus 61 OpExpanderFactory::RegFunc func_; 62 }; 63 64 #define OP_EXPANDER_REGISTER(name, cls) \ 65 static const OpExpanderRegister g_##cls##_expander_reg( \ 66 name, []() -> std::shared_ptr<OpExpander> { return std::make_shared<cls>(); }) 67 } // namespace expanders 68 } // namespace opt 69 } // namespace mindspore 70 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_ 71