• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "frontend/parallel/ops_info/ops_info_head_files.h"
26 #include "frontend/parallel/step_parallel.h"
27 
28 namespace mindspore {
29 namespace parallel {
30 #define REGISTER(className)                                                                                  \
31   OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \
32     return std::make_shared<className>(name, in, out, attrs);                                                \
33   }                                                                                                          \
34   RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
35 
36 typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
37                                    const PrimitiveAttrs &attrs);
38 
39 class DynCreator {
40  public:
41   ~DynCreator() = default;
42 
43   // create static singleton dyn_creator instance
Instance()44   static DynCreator &Instance() {
45     static DynCreator fac = DynCreator();
46     return fac;
47   }
48   // register
Register(std::string name,CreatFn func)49   void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
50   // creator
Create(const std::string & name,const Shapes & shape_in,const Shapes & shape_out,const PrimitiveAttrs & attrs,size_t count)51   OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
52                          const PrimitiveAttrs &attrs, size_t count) {
53     std::string op_name = name + std::to_string(count);
54     auto iter = Function_map_.find(name);
55     if (iter == Function_map_.end()) {
56       MS_LOG(INFO) << name << " is not register yet";
57       return nullptr;
58     }
59     return iter->second(op_name, shape_in, shape_out, attrs);
60   }
61 
62  private:
63   DynCreator() = default;
64   std::map<std::string, CreatFn> Function_map_;
65 };
66 
67 class RegisterAction {
68  public:
RegisterAction(const std::string & name,CreatFn creatfn)69   RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
70     DynCreator::Instance().Register(name, creatfn);
71   }
72   ~RegisterAction() = default;
73 
74  private:
75   std::string name_;
76 };
77 
78 // operator register
79 REGISTER(MatMulInfo);
80 REGISTER(GeLUInfo);
81 REGISTER(FastGeLUInfo);
82 REGISTER(VirtualDatasetInfo);
83 REGISTER(BatchParallelInfo);
84 REGISTER(TanhInfo);
85 REGISTER(SoftmaxInfo);
86 REGISTER(LogSoftmaxInfo);
87 REGISTER(ActivationInfo);
88 REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
89 REGISTER(SubInfo);
90 REGISTER(AddInfo);
91 REGISTER(BiasAddInfo);
92 REGISTER(MulInfo);
93 REGISTER(DivInfo);
94 REGISTER(ModInfo);
95 REGISTER(RealDivInfo);
96 REGISTER(PowInfo);
97 REGISTER(ExpInfo);
98 REGISTER(OneHotInfo);
99 REGISTER(EqualInfo);
100 REGISTER(NotEqualInfo);
101 REGISTER(LogInfo);
102 REGISTER(CosInfo);
103 REGISTER(ACosInfo);
104 REGISTER(LogicalNotInfo);
105 REGISTER(L2NormalizeInfo);
106 REGISTER(LayerNormInfo);
107 REGISTER(ReduceMaxInfo);
108 REGISTER(ArgMaxWithValueInfo);
109 REGISTER(ArgMinWithValueInfo);
110 REGISTER(ReduceMeanInfo);
111 REGISTER(ReduceSumInfo);
112 REGISTER(ReduceMinInfo);
113 REGISTER(TransposeInfo);
114 REGISTER(PReLUInfo);
115 REGISTER(DropoutDoMaskInfo);
116 REGISTER(ReshapeInfo);
117 REGISTER(FloorDivInfo);
118 REGISTER(MaximumInfo);
119 REGISTER(MinimumInfo);
120 REGISTER(CastInfo);
121 REGISTER(GreaterInfo);
122 REGISTER(GreaterEqualInfo);
123 REGISTER(LessEqualInfo);
124 REGISTER(LessInfo);
125 REGISTER(ApproximateEqualInfo);
126 REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
127 REGISTER(AssignSubInfo);
128 REGISTER(FloorModInfo);
129 REGISTER(AssignInfo);
130 REGISTER(AssignAddInfo);
131 REGISTER(Atan2Info);
132 REGISTER(DivNoNanInfo);
133 REGISTER(LogicalAndInfo);
134 REGISTER(LogicalOrInfo);
135 REGISTER(EluInfo);
136 REGISTER(ReLUInfo);
137 REGISTER(RepeatElementsInfo);
138 REGISTER(TensorDotInfo);
139 REGISTER(RangeInfo);
140 REGISTER(ReLU6Info);
141 REGISTER(ReLUV2Info);
142 REGISTER(SoftplusInfo);
143 REGISTER(SoftsignInfo);
144 REGISTER(GatherInfo);
145 REGISTER(SparseGatherV2Info);
146 REGISTER(SqrtInfo);
147 REGISTER(SigmoidInfo);
148 REGISTER(GetNextInfo);
149 REGISTER(NegInfo);
150 REGISTER(AbsInfo);
151 REGISTER(AcoshInfo);
152 REGISTER(AsinInfo);
153 REGISTER(AsinhInfo);
154 REGISTER(AtanInfo);
155 REGISTER(AtanhInfo);
156 REGISTER(CeilInfo);
157 REGISTER(CoshInfo);
158 REGISTER(Expm1Info);
159 REGISTER(Log1pInfo);
160 REGISTER(SinInfo);
161 REGISTER(SinhInfo);
162 REGISTER(TanInfo);
163 REGISTER(RsqrtInfo);
164 REGISTER(InvInfo);
165 REGISTER(ReciprocalInfo);
166 REGISTER(RoundInfo);
167 REGISTER(FloorInfo);
168 REGISTER(SignInfo);
169 REGISTER(ErfInfo);
170 REGISTER(ErfcInfo);
171 REGISTER(ZerosLikeInfo);
172 REGISTER(OnesLikeInfo);
173 REGISTER(BesselI0eInfo);
174 REGISTER(BesselI1eInfo);
175 REGISTER(BatchMatMulInfo);
176 REGISTER(ExpandDimsInfo);
177 REGISTER(SqueezeInfo);
178 REGISTER(SigmoidCrossEntropyWithLogitsInfo);
179 REGISTER(SquareInfo);
180 REGISTER(UniformCandidateSamplerInfo);
181 REGISTER(UnsortedSegmentSumInfo);
182 REGISTER(UnsortedSegmentMinInfo);
183 REGISTER(UnsortedSegmentMaxInfo);
184 REGISTER(GatherPInfo);
185 REGISTER(EmbeddingLookupInfo);
186 REGISTER(TileInfo);
187 REGISTER(BroadcastToInfo);
188 REGISTER(StridedSliceInfo);
189 REGISTER(SliceInfo);
190 REGISTER(DropoutInfo);
191 REGISTER(StackInfo);
192 REGISTER(ConcatInfo);
193 REGISTER(SplitInfo);
194 REGISTER(UniqueInfo);
195 REGISTER(SelectInfo);
196 REGISTER(GatherNdInfo);
197 REGISTER(TopKInfo);
198 REGISTER(ScatterUpdateInfo);
199 REGISTER(VirtualOutputInfo);
200 REGISTER(Conv2DInfo);
201 REGISTER(Conv2DBackpropInputInfo);
202 REGISTER(Conv2DTransposeInfo);
203 REGISTER(BatchNormInfo);
204 REGISTER(MaxPoolInfo);
205 REGISTER(AvgPoolInfo);
206 REGISTER(GatherDInfo);
207 REGISTER(ReduceAnyInfo);
208 REGISTER(MatmulDDSInfo);
209 REGISTER(DSDMatmulInfo);
210 REGISTER(ResizeBilinearInfo);
211 REGISTER(ResizeNearestNeighborInfo);
212 REGISTER(UniformRealInfo);
213 }  // namespace parallel
214 }  // namespace mindspore
215 
216 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
217