1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_ 19 20 #include <string> 21 #include <memory> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "ir/value.h" 26 #include "frontend/parallel/graph_util/generate_graph.h" 27 #include "frontend/parallel/auto_parallel/operator_costmodel.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 #include "frontend/parallel/strategy.h" 30 31 namespace mindspore { 32 namespace parallel { 33 class Conv2DInfo : public OperatorInfo { 34 public: Conv2DInfo(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)35 Conv2DInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, 36 const PrimitiveAttrs &attrs) 37 : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {} 38 ~Conv2DInfo() override = default; 39 40 Status Init(const StrategyPtr &strategy) override; 41 Status InitForCostModel(const StrategyPtr &strategy) override; 42 std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; 43 Status SetCostUnderStrategy(const StrategyPtr &) override; 44 void ReComputeBatchSplitFlagList() override; 45 46 protected: 47 Status GetAttrsBase(); 48 Status GetAttrs() override; 49 Status CheckStrategyBase(const StrategyPtr &strategy); 50 Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const; 51 Status CheckStrategy(const StrategyPtr &strategy) override; 52 Status InferForwardCommunication() override; 53 Status InferDevMatrixShape() override; 54 Status InferTensorMap() override; 55 Status InferRankBias(); 56 void InferOverlapSize(); 57 void InferNewOperatorAttrs(); 58 void InferSendRecvFlag(); 59 void InferOverlapShapes(); 60 void InferStridedSliceAttrs(); 61 std::string ReplaceNodeName() const; 62 AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode); 63 ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; 64 OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode); 65 OperatorAttrs CreateConv2DAttrs(); 66 void ComputeReplaceGraph(const CNodePtr &cnode); 67 68 int64_t out_channel_ = 1; 69 std::vector<int64_t> kernel_size_; // two integers 70 int64_t mode_ = 1; 71 int64_t pad_mode_ = 0; // "pad": 0; "same": 1; "valid": 2; 72 std::vector<int64_t> pad_list_; // four integers 73 std::vector<int64_t> stride_; // four integers 74 std::vector<int64_t> dilation_; // four integers 75 int64_t group_ = 1; 76 std::string format_; 77 bool out_channel_shard_ = false; 78 int64_t new_out_channel_ = 1; 79 std::vector<int64_t> new_pad_list_; 80 81 bool need_exchange_overlap_ = false; 82 int64_t rank_bias_ = 0; 83 int64_t left_rank_bias_ = -1; 84 int64_t right_rank_bias_ = -1; 85 int64_t left_rank_id_ = -1; 86 int64_t right_rank_id_ = -1; 87 int64_t overlap_left_size_ = 0; 88 int64_t overlap_right_size_ = 0; 89 int64_t left_rank_overlap_left_size_ = 0; 90 int64_t left_rank_overlap_right_size_ = 0; 91 int64_t right_rank_overlap_left_size_ = 0; 92 int64_t right_rank_overlap_right_size_ = 0; 93 int64_t w_dimension_shard_num_ = 1; 94 Shape input_slice_shape_; 95 96 bool left_need_send_ = false; 97 bool left_need_recv_ = false; 98 bool right_need_send_ = false; 99 bool right_need_recv_ = false; 100 Shape left_strided_slice_begin_; 101 Shape left_strided_slice_end_; 102 Shape left_strided_slice_strides_; 103 Shape right_strided_slice_begin_; 104 Shape right_strided_slice_end_; 105 Shape right_strided_slice_strides_; 106 107 std::vector<int64_t> send_rank_ids_; 108 std::vector<int64_t> recv_rank_ids_; 109 Shapes send_shapes_; 110 Shapes recv_shapes_; 111 112 GenerateGraph gen_g_ = GenerateGraph(attrs_); 113 114 virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); 115 virtual void InferNewPadList(); 116 virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias); 117 virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias); 118 119 private: 120 Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy); 121 Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy); 122 }; 123 124 class Conv2DBackpropInputInfo : public Conv2DInfo { 125 public: Conv2DBackpropInputInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)126 Conv2DBackpropInputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 127 const PrimitiveAttrs &attrs) 128 : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {} 129 ~Conv2DBackpropInputInfo() override = default; 130 void UpdateOutShape(); 131 void ReplaceNodeInputOrAttrs() override; 132 133 protected: 134 Status GetAttrs() override; 135 Status GetOutShape(); 136 Status CheckStrategy(const StrategyPtr &strategy) override; 137 Status InferDevMatrixShape() override; 138 Status InferTensorMap() override; 139 Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor 140 141 Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override; 142 void InferNewPadList() override; 143 int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override; 144 int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override; 145 146 private: 147 Shape out_shape_; 148 Shape out_slice_shape_; 149 }; 150 151 class Conv2DTransposeInfo : public Conv2DBackpropInputInfo { 152 public: Conv2DTransposeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)153 Conv2DTransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, 154 const PrimitiveAttrs &attrs) 155 : Conv2DBackpropInputInfo(name, inputs_shape, outputs_shape, attrs) {} 156 ~Conv2DTransposeInfo() override = default; 157 }; 158 159 constexpr size_t IN_CHANNEL_INDEX = 1; 160 using Conv2DBackpropInputInfoPtr = std::shared_ptr<Conv2DBackpropInputInfo>; 161 } // namespace parallel 162 } // namespace mindspore 163 164 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_ 165