1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ 19 20 #include <memory> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "ir/value.h" 26 #include "frontend/parallel/auto_parallel/operator_costmodel.h" 27 #include "frontend/parallel/ops_info/operator_info.h" 28 #include "frontend/parallel/strategy.h" 29 30 namespace mindspore { 31 namespace parallel { 32 class ArithmeticBase : public OperatorInfo { 33 public: ArithmeticBase(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs,OperatorCostPtr cost)34 ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, 35 const PrimitiveAttrs &attrs, OperatorCostPtr cost) 36 : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} 37 ~ArithmeticBase() override = default; 38 Status Init(const StrategyPtr &strategy) override; 39 Status InitForCostModel(const StrategyPtr &strategy) override; 40 std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; 41 Status SetCostUnderStrategy(const StrategyPtr &) override; 42 void ReComputeBatchSplitFlagList() override; 43 44 protected: GetAttrs()45 Status GetAttrs() override { return SUCCESS; } 46 Status CheckStrategy(const StrategyPtr &strategy) override; InferForwardCommunication()47 Status InferForwardCommunication() override { return SUCCESS; } 48 Status InferDevMatrixShape() override; 49 Status InferTensorMap() override; 50 Shapes InferExpendShape(); 51 }; 52 53 class SubInfo : public ArithmeticBase { 54 public: SubInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)55 SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 56 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {} 57 ~SubInfo() override = default; 58 }; 59 60 class AddInfo : public ArithmeticBase { 61 public: AddInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)62 AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 63 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {} 64 ~AddInfo() override = default; 65 }; 66 67 class MulInfo : public ArithmeticBase { 68 public: MulInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)69 MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 70 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {} 71 ~MulInfo() override = default; 72 }; 73 74 class DivInfo : public ArithmeticBase { 75 public: DivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)76 DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 77 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {} 78 ~DivInfo() override = default; 79 }; 80 81 class ModInfo : public ArithmeticBase { 82 public: ModInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)83 ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 84 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {} 85 ~ModInfo() override = default; 86 }; 87 88 class RealDivInfo : public ArithmeticBase { 89 public: RealDivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)90 RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 91 const PrimitiveAttrs &attrs) 92 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReadDivCost>()) {} 93 ~RealDivInfo() override = default; 94 }; 95 96 class FloorDivInfo : public ArithmeticBase { 97 public: FloorDivInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)98 FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 99 const PrimitiveAttrs &attrs) 100 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {} 101 ~FloorDivInfo() override = default; 102 }; 103 104 class FloorModInfo : public ArithmeticBase { 105 public: FloorModInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)106 FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 107 const PrimitiveAttrs &attrs) 108 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {} 109 ~FloorModInfo() override = default; 110 }; 111 112 class PowInfo : public ArithmeticBase { 113 public: PowInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)114 PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) 115 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {} 116 ~PowInfo() override = default; 117 }; 118 119 class AssignSubInfo : public ArithmeticBase { 120 public: AssignSubInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)121 AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 122 const PrimitiveAttrs &attrs) 123 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {} 124 ~AssignSubInfo() override = default; 125 }; 126 127 class AssignInfo : public ArithmeticBase { 128 public: AssignInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)129 AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 130 const PrimitiveAttrs &attrs) 131 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {} 132 ~AssignInfo() override = default; 133 }; 134 135 class AssignAddInfo : public ArithmeticBase { 136 public: AssignAddInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)137 AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 138 const PrimitiveAttrs &attrs) 139 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {} 140 ~AssignAddInfo() override = default; 141 }; 142 143 // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. 144 class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { 145 public: SigmoidCrossEntropyWithLogitsInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)146 SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 147 const PrimitiveAttrs &attrs) 148 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, 149 std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {} 150 ~SigmoidCrossEntropyWithLogitsInfo() override = default; 151 }; 152 153 class Atan2Info : public ArithmeticBase { 154 public: Atan2Info(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)155 Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 156 const PrimitiveAttrs &attrs) 157 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {} 158 ~Atan2Info() override = default; 159 }; 160 161 class DivNoNanInfo : public ArithmeticBase { 162 public: DivNoNanInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)163 DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 164 const PrimitiveAttrs &attrs) 165 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {} 166 ~DivNoNanInfo() override = default; 167 }; 168 169 class LogicalAndInfo : public ArithmeticBase { 170 public: LogicalAndInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)171 LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 172 const PrimitiveAttrs &attrs) 173 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {} 174 ~LogicalAndInfo() override = default; 175 }; 176 177 class LogicalOrInfo : public ArithmeticBase { 178 public: LogicalOrInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)179 LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 180 const PrimitiveAttrs &attrs) 181 : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {} 182 ~LogicalOrInfo() override = default; 183 }; 184 } // namespace parallel 185 } // namespace mindspore 186 187 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ 188