• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "frontend/parallel/step_parallel.h"
26 
27 namespace mindspore {
28 namespace parallel {
29 #define REGISTER2(opName, className)                                                                      \
30   OperatorInfoPtr objectCreator##opName(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \
31     return std::make_shared<className>(name, in, out, attrs);                                             \
32   }                                                                                                       \
33   RegisterAction opName##Register(#opName, reinterpret_cast<CreatFn>(objectCreator##opName));
34 
35 typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
36                                    const PrimitiveAttrs &attrs);
37 
38 #define REGISTER(className) REGISTER2(className, className)
39 
40 class DynCreator {
41  public:
42   ~DynCreator() = default;
43 
44   // create static singleton dyn_creator instance
Instance()45   static DynCreator &Instance() {
46     static DynCreator fac = DynCreator();
47     return fac;
48   }
49   // register
Register(std::string name,CreatFn func)50   void Register(std::string name, CreatFn func) { (void)function_map_.insert(std::make_pair(name, func)); }
51   // creator
Create(const std::string & name,const Shapes & shape_in,const Shapes & shape_out,const PrimitiveAttrs & attrs,size_t count)52   OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
53                          const PrimitiveAttrs &attrs, size_t count) {
54     std::string op_name = name + std::to_string(count);
55     const auto iter = function_map_.find(name);
56     if (iter == function_map_.end()) {
57       MS_LOG(INFO) << name << " is not register yet";
58       return nullptr;
59     }
60     return iter->second(op_name, shape_in, shape_out, attrs);
61   }
62 
63  private:
64   DynCreator() = default;
65   std::map<std::string, CreatFn> function_map_;
66 };
67 
68 class RegisterAction {
69  public:
RegisterAction(const std::string & name,CreatFn creatfn)70   RegisterAction(const std::string &name, CreatFn creatfn) noexcept : name_(name) {
71     DynCreator::Instance().Register(name, creatfn);
72   }
73   ~RegisterAction() = default;
74 
75  private:
76   std::string name_;
77 };
78 }  // namespace parallel
79 }  // namespace mindspore
80 
81 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
82