1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 25 #include "frontend/parallel/step_parallel.h" 26 27 namespace mindspore { 28 namespace parallel { 29 #define REGISTER2(opName, className) \ 30 OperatorInfoPtr objectCreator##opName(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ 31 return std::make_shared<className>(name, in, out, attrs); \ 32 } \ 33 RegisterAction opName##Register(#opName, reinterpret_cast<CreatFn>(objectCreator##opName)); 34 35 typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, 36 const PrimitiveAttrs &attrs); 37 38 #define REGISTER(className) REGISTER2(className, className) 39 40 class DynCreator { 41 public: 42 ~DynCreator() = default; 43 44 // create static singleton dyn_creator instance Instance()45 static DynCreator &Instance() { 46 static DynCreator fac = DynCreator(); 47 return fac; 48 } 49 // register Register(std::string name,CreatFn func)50 void Register(std::string name, CreatFn func) { (void)function_map_.insert(std::make_pair(name, func)); } 51 // creator Create(const std::string & name,const Shapes & shape_in,const Shapes & shape_out,const PrimitiveAttrs & attrs,size_t count)52 OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, 53 const PrimitiveAttrs &attrs, size_t count) { 54 std::string op_name = name + std::to_string(count); 55 const auto iter = function_map_.find(name); 56 if (iter == function_map_.end()) { 57 MS_LOG(INFO) << name << " is not register yet"; 58 return nullptr; 59 } 60 return iter->second(op_name, shape_in, shape_out, attrs); 61 } 62 63 private: 64 DynCreator() = default; 65 std::map<std::string, CreatFn> function_map_; 66 }; 67 68 class RegisterAction { 69 public: RegisterAction(const std::string & name,CreatFn creatfn)70 RegisterAction(const std::string &name, CreatFn creatfn) noexcept : name_(name) { 71 DynCreator::Instance().Register(name, creatfn); 72 } 73 ~RegisterAction() = default; 74 75 private: 76 std::string name_; 77 }; 78 } // namespace parallel 79 } // namespace mindspore 80 81 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 82