1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 25 #include "frontend/parallel/ops_info/ops_info_head_files.h" 26 #include "frontend/parallel/step_parallel.h" 27 28 namespace mindspore { 29 namespace parallel { 30 #define REGISTER(className) \ 31 OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ 32 return std::make_shared<className>(name, in, out, attrs); \ 33 } \ 34 RegisterAction className##Register(#className, (CreatFn)objectCreator##className); 35 36 typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, 37 const PrimitiveAttrs &attrs); 38 39 class DynCreator { 40 public: 41 ~DynCreator() = default; 42 43 // create static singleton dyn_creator instance Instance()44 static DynCreator &Instance() { 45 static DynCreator fac = DynCreator(); 46 return fac; 47 } 48 // register Register(std::string name,CreatFn func)49 void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } 50 // creator Create(const std::string & name,const Shapes & shape_in,const Shapes & shape_out,const PrimitiveAttrs & attrs,size_t count)51 OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, 52 const PrimitiveAttrs &attrs, size_t count) { 53 std::string op_name = name + std::to_string(count); 54 auto iter = Function_map_.find(name); 55 if (iter == Function_map_.end()) { 56 MS_LOG(INFO) << name << " is not register yet"; 57 return nullptr; 58 } 59 return iter->second(op_name, shape_in, shape_out, attrs); 60 } 61 62 private: 63 DynCreator() = default; 64 std::map<std::string, CreatFn> Function_map_; 65 }; 66 67 class RegisterAction { 68 public: RegisterAction(const std::string & name,CreatFn creatfn)69 RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { 70 DynCreator::Instance().Register(name, creatfn); 71 } 72 ~RegisterAction() = default; 73 74 private: 75 std::string name_; 76 }; 77 78 // operator register 79 REGISTER(MatMulInfo); 80 REGISTER(GeLUInfo); 81 REGISTER(FastGeLUInfo); 82 REGISTER(VirtualDatasetInfo); 83 REGISTER(BatchParallelInfo); 84 REGISTER(TanhInfo); 85 REGISTER(SoftmaxInfo); 86 REGISTER(LogSoftmaxInfo); 87 REGISTER(ActivationInfo); 88 REGISTER(SoftmaxCrossEntropyWithLogitsInfo); 89 REGISTER(SubInfo); 90 REGISTER(AddInfo); 91 REGISTER(BiasAddInfo); 92 REGISTER(MulInfo); 93 REGISTER(DivInfo); 94 REGISTER(ModInfo); 95 REGISTER(RealDivInfo); 96 REGISTER(PowInfo); 97 REGISTER(ExpInfo); 98 REGISTER(OneHotInfo); 99 REGISTER(EqualInfo); 100 REGISTER(NotEqualInfo); 101 REGISTER(LogInfo); 102 REGISTER(CosInfo); 103 REGISTER(ACosInfo); 104 REGISTER(LogicalNotInfo); 105 REGISTER(L2NormalizeInfo); 106 REGISTER(LayerNormInfo); 107 REGISTER(ReduceMaxInfo); 108 REGISTER(ArgMaxWithValueInfo); 109 REGISTER(ArgMinWithValueInfo); 110 REGISTER(ReduceMeanInfo); 111 REGISTER(ReduceSumInfo); 112 REGISTER(ReduceMinInfo); 113 REGISTER(TransposeInfo); 114 REGISTER(PReLUInfo); 115 REGISTER(DropoutDoMaskInfo); 116 REGISTER(ReshapeInfo); 117 REGISTER(FloorDivInfo); 118 REGISTER(MaximumInfo); 119 REGISTER(MinimumInfo); 120 REGISTER(CastInfo); 121 REGISTER(GreaterInfo); 122 REGISTER(GreaterEqualInfo); 123 REGISTER(LessEqualInfo); 124 REGISTER(LessInfo); 125 REGISTER(ApproximateEqualInfo); 126 REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); 127 REGISTER(AssignSubInfo); 128 REGISTER(FloorModInfo); 129 REGISTER(AssignInfo); 130 REGISTER(AssignAddInfo); 131 REGISTER(Atan2Info); 132 REGISTER(DivNoNanInfo); 133 REGISTER(LogicalAndInfo); 134 REGISTER(LogicalOrInfo); 135 REGISTER(EluInfo); 136 REGISTER(ReLUInfo); 137 REGISTER(RepeatElementsInfo); 138 REGISTER(TensorDotInfo); 139 REGISTER(RangeInfo); 140 REGISTER(ReLU6Info); 141 REGISTER(ReLUV2Info); 142 REGISTER(SoftplusInfo); 143 REGISTER(SoftsignInfo); 144 REGISTER(GatherInfo); 145 REGISTER(SparseGatherV2Info); 146 REGISTER(SqrtInfo); 147 REGISTER(SigmoidInfo); 148 REGISTER(GetNextInfo); 149 REGISTER(NegInfo); 150 REGISTER(AbsInfo); 151 REGISTER(AcoshInfo); 152 REGISTER(AsinInfo); 153 REGISTER(AsinhInfo); 154 REGISTER(AtanInfo); 155 REGISTER(AtanhInfo); 156 REGISTER(CeilInfo); 157 REGISTER(CoshInfo); 158 REGISTER(Expm1Info); 159 REGISTER(Log1pInfo); 160 REGISTER(SinInfo); 161 REGISTER(SinhInfo); 162 REGISTER(TanInfo); 163 REGISTER(RsqrtInfo); 164 REGISTER(InvInfo); 165 REGISTER(ReciprocalInfo); 166 REGISTER(RoundInfo); 167 REGISTER(FloorInfo); 168 REGISTER(SignInfo); 169 REGISTER(ErfInfo); 170 REGISTER(ErfcInfo); 171 REGISTER(ZerosLikeInfo); 172 REGISTER(OnesLikeInfo); 173 REGISTER(BesselI0eInfo); 174 REGISTER(BesselI1eInfo); 175 REGISTER(BatchMatMulInfo); 176 REGISTER(ExpandDimsInfo); 177 REGISTER(SqueezeInfo); 178 REGISTER(SigmoidCrossEntropyWithLogitsInfo); 179 REGISTER(SquareInfo); 180 REGISTER(UniformCandidateSamplerInfo); 181 REGISTER(UnsortedSegmentSumInfo); 182 REGISTER(UnsortedSegmentMinInfo); 183 REGISTER(UnsortedSegmentMaxInfo); 184 REGISTER(GatherPInfo); 185 REGISTER(EmbeddingLookupInfo); 186 REGISTER(TileInfo); 187 REGISTER(BroadcastToInfo); 188 REGISTER(StridedSliceInfo); 189 REGISTER(SliceInfo); 190 REGISTER(DropoutInfo); 191 REGISTER(StackInfo); 192 REGISTER(ConcatInfo); 193 REGISTER(SplitInfo); 194 REGISTER(UniqueInfo); 195 REGISTER(SelectInfo); 196 REGISTER(GatherNdInfo); 197 REGISTER(TopKInfo); 198 REGISTER(ScatterUpdateInfo); 199 REGISTER(VirtualOutputInfo); 200 REGISTER(Conv2DInfo); 201 REGISTER(Conv2DBackpropInputInfo); 202 REGISTER(Conv2DTransposeInfo); 203 REGISTER(BatchNormInfo); 204 REGISTER(MaxPoolInfo); 205 REGISTER(AvgPoolInfo); 206 REGISTER(GatherDInfo); 207 REGISTER(ReduceAnyInfo); 208 REGISTER(MatmulDDSInfo); 209 REGISTER(DSDMatmulInfo); 210 REGISTER(ResizeBilinearInfo); 211 REGISTER(ResizeNearestNeighborInfo); 212 REGISTER(UniformRealInfo); 213 } // namespace parallel 214 } // namespace mindspore 215 216 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_ 217