1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ 19 20 #include <ir/value.h> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "frontend/parallel/auto_parallel/operator_costmodel.h" 27 #include "frontend/parallel/ops_info/operator_info.h" 28 #include "frontend/parallel/strategy.h" 29 30 namespace mindspore { 31 namespace parallel { 32 class ActivationBase : public OperatorInfo { 33 public: ActivationBase(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)34 ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, 35 const PrimitiveAttrs &attrs, OperatorCostPtr cost) 36 : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} 37 ~ActivationBase() override = default; 38 39 Status Init(const StrategyPtr &strategy) override; 40 Status InitForCostModel(const StrategyPtr &strategy) override; 41 42 protected: 43 Status InferMirrorOps() override; 44 Status InferForwardCommunication() override; 45 Status InferTensorMap() override; 46 Status InferDevMatrixShape() override; 47 }; 48 49 class Activation : public ActivationBase { 50 public: Activation(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)51 Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 52 const PrimitiveAttrs &attrs, OperatorCostPtr cost) 53 : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {} 54 ~Activation() override = default; 55 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 56 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; 57 58 protected: 59 Status CheckStrategy(const StrategyPtr &strategy) override; 60 }; 61 62 class ActivationInfo : public Activation { 63 public: ActivationInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)64 ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 65 const PrimitiveAttrs &attrs) 66 : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {} 67 ~ActivationInfo() override = default; 68 69 protected: 70 Status GetAttrs() override; // activation_type: relu, relu6, sigmoid 71 }; 72 73 class ActivationOther : public Activation { 74 public: ActivationOther(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)75 ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 76 const PrimitiveAttrs &attrs, OperatorCostPtr cost) 77 : Activation(name, inputs_shape, outputs_shape, attrs, cost) {} 78 ~ActivationOther() override = default; 79 80 protected: 81 Status GetAttrs() override; 82 }; 83 84 class GeLUInfo : public ActivationOther { 85 public: GeLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)86 GeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 87 const PrimitiveAttrs &attrs) 88 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {} 89 ~GeLUInfo() override = default; 90 }; 91 92 class FastGeLUInfo : public ActivationOther { 93 public: FastGeLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)94 FastGeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 95 const PrimitiveAttrs &attrs) 96 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FastGeLUCost>()) {} 97 ~FastGeLUInfo() override = default; 98 }; 99 100 class TanhInfo : public ActivationOther { 101 public: TanhInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)102 TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 103 const PrimitiveAttrs &attrs) 104 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {} 105 ~TanhInfo() override = default; 106 }; 107 108 class Softmax : public ActivationBase { 109 public: Softmax(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)110 explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 111 const PrimitiveAttrs &attrs) 112 : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {} 113 ~Softmax() override = default; 114 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 115 Status SetCostUnderStrategy(const StrategyPtr &strategy) override; 116 117 protected: 118 Status CheckStrategy(const StrategyPtr &strategy) override; 119 Status GetAttrs() override; 120 121 private: 122 std::vector<int64_t> axis_; 123 }; 124 125 class SoftmaxInfo : public Softmax { 126 public: SoftmaxInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)127 SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 128 const PrimitiveAttrs &attrs) 129 : Softmax(name, inputs_shape, outputs_shape, attrs) {} 130 ~SoftmaxInfo() override = default; 131 }; 132 133 class LogSoftmaxInfo : public Softmax { 134 public: LogSoftmaxInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)135 LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 136 const PrimitiveAttrs &attrs) 137 : Softmax(name, inputs_shape, outputs_shape, attrs) {} 138 ~LogSoftmaxInfo() override = default; 139 }; 140 141 class EluInfo : public ActivationOther { 142 public: EluInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)143 EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 144 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {} 145 ~EluInfo() override = default; 146 }; 147 148 class ReLUInfo : public ActivationOther { 149 public: ReLUInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)150 ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 151 const PrimitiveAttrs &attrs) 152 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {} 153 ~ReLUInfo() override = default; 154 }; 155 156 class RepeatElementsInfo : public ActivationOther { 157 public: RepeatElementsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)158 RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 159 const PrimitiveAttrs &attrs) 160 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {} 161 ~RepeatElementsInfo() override = default; 162 }; 163 164 class ReLU6Info : public ActivationOther { 165 public: ReLU6Info(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)166 ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 167 const PrimitiveAttrs &attrs) 168 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {} 169 ~ReLU6Info() override = default; 170 }; 171 172 class SoftsignInfo : public ActivationOther { 173 public: SoftsignInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)174 SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 175 const PrimitiveAttrs &attrs) 176 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {} 177 ~SoftsignInfo() override = default; 178 }; 179 180 class SoftplusInfo : public ActivationOther { 181 public: SoftplusInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)182 SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 183 const PrimitiveAttrs &attrs) 184 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {} 185 ~SoftplusInfo() override = default; 186 }; 187 188 class CastInfo : public ActivationOther { 189 public: CastInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)190 CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 191 const PrimitiveAttrs &attrs) 192 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {} 193 ~CastInfo() override = default; 194 195 protected: 196 Status InferMirrorOps() override; 197 }; 198 199 class SqrtInfo : public ActivationOther { 200 public: SqrtInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)201 SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 202 const PrimitiveAttrs &attrs) 203 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {} 204 ~SqrtInfo() override = default; 205 }; 206 207 class NegInfo : public ActivationOther { 208 public: NegInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)209 NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 210 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {} 211 ~NegInfo() override = default; 212 }; 213 214 class ExpandDimsInfo : public ActivationOther { 215 public: ExpandDimsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)216 ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 217 const PrimitiveAttrs &attrs) 218 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {} 219 ~ExpandDimsInfo() override = default; 220 221 protected: 222 Status GetAttrs() override; 223 Status InferTensorMap() override; 224 Status InferMirrorOps() override; 225 Status InferTensorStrategy(); 226 227 private: 228 int64_t positive_axis_ = -1; 229 Strategys inputs_strategy_; 230 Strategys outputs_strategy_; 231 }; 232 233 class SqueezeInfo : public ActivationOther { 234 public: SqueezeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)235 SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 236 const PrimitiveAttrs &attrs) 237 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {} 238 ~SqueezeInfo() override = default; 239 240 protected: 241 Status InferAxis(const ValueTuplePtr &value_tuple); 242 Status GetAttrs() override; 243 Status InferReplaceOps(); 244 Status InferTensorMap() override; 245 Status Init(const StrategyPtr &strategy) override; 246 247 private: 248 ValueTuplePtr axis_; 249 }; 250 251 class SquareInfo : public ActivationOther { 252 public: SquareInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)253 SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 254 const PrimitiveAttrs &attrs) 255 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {} 256 ~SquareInfo() override = default; 257 }; 258 259 class SigmoidInfo : public ActivationOther { 260 public: SigmoidInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)261 SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 262 const PrimitiveAttrs &attrs) 263 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {} 264 ~SigmoidInfo() override = default; 265 }; 266 267 class DropoutInfo : public ActivationOther { 268 public: DropoutInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)269 DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 270 const PrimitiveAttrs &attrs) 271 : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {} 272 ~DropoutInfo() override = default; 273 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; 274 Status Init(const StrategyPtr &strategy) override; 275 276 protected: 277 Status GetAttrs() override; 278 Status InferTensorMap() override; 279 Status InferReplaceOps(); 280 Status InferAsLossDivisor() override; 281 282 private: 283 int64_t seed0_ = 0; 284 int64_t seed1_ = 0; get_seed()285 int64_t get_seed() { 286 static int64_t SEED_NUM; 287 return ++SEED_NUM; 288 } 289 }; 290 } // namespace parallel 291 } // namespace mindspore 292 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ 293