• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
19 
20 #include <string>
21 #include <memory>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "ir/value.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
28 #include "frontend/parallel/ops_info/operator_info.h"
29 #include "frontend/parallel/strategy.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 class Conv2DInfo : public OperatorInfo {
34  public:
Conv2DInfo(const std::string & operator_name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)35   Conv2DInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
36              const PrimitiveAttrs &attrs)
37       : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
38   ~Conv2DInfo() override = default;
39 
40   Status Init(const StrategyPtr &strategy) override;
41   Status InitForCostModel(const StrategyPtr &strategy) override;
42   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
43   Status SetCostUnderStrategy(const StrategyPtr &) override;
44   void ReComputeBatchSplitFlagList() override;
45 
46  protected:
47   Status GetAttrsBase();
48   Status GetAttrs() override;
49   Status CheckStrategyBase(const StrategyPtr &strategy);
50   Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const;
51   Status CheckStrategy(const StrategyPtr &strategy) override;
52   Status InferForwardCommunication() override;
53   Status InferDevMatrixShape() override;
54   Status InferTensorMap() override;
55   Status InferRankBias();
56   void InferOverlapSize();
57   void InferNewOperatorAttrs();
58   void InferSendRecvFlag();
59   void InferOverlapShapes();
60   void InferStridedSliceAttrs();
61   std::string ReplaceNodeName() const;
62   AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
63   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
64   OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
65   OperatorAttrs CreateConv2DAttrs();
66   void ComputeReplaceGraph(const CNodePtr &cnode);
67 
68   int64_t out_channel_ = 1;
69   std::vector<int64_t> kernel_size_;  // two integers
70   int64_t mode_ = 1;
71   int64_t pad_mode_ = 0;           // "pad": 0; "same": 1; "valid": 2;
72   std::vector<int64_t> pad_list_;  // four integers
73   std::vector<int64_t> stride_;    // four integers
74   std::vector<int64_t> dilation_;  // four integers
75   int64_t group_ = 1;
76   std::string format_;
77   bool out_channel_shard_ = false;
78   int64_t new_out_channel_ = 1;
79   std::vector<int64_t> new_pad_list_;
80 
81   bool need_exchange_overlap_ = false;
82   int64_t rank_bias_ = 0;
83   int64_t left_rank_bias_ = -1;
84   int64_t right_rank_bias_ = -1;
85   int64_t left_rank_id_ = -1;
86   int64_t right_rank_id_ = -1;
87   int64_t overlap_left_size_ = 0;
88   int64_t overlap_right_size_ = 0;
89   int64_t left_rank_overlap_left_size_ = 0;
90   int64_t left_rank_overlap_right_size_ = 0;
91   int64_t right_rank_overlap_left_size_ = 0;
92   int64_t right_rank_overlap_right_size_ = 0;
93   int64_t w_dimension_shard_num_ = 1;
94   Shape input_slice_shape_;
95 
96   bool left_need_send_ = false;
97   bool left_need_recv_ = false;
98   bool right_need_send_ = false;
99   bool right_need_recv_ = false;
100   Shape left_strided_slice_begin_;
101   Shape left_strided_slice_end_;
102   Shape left_strided_slice_strides_;
103   Shape right_strided_slice_begin_;
104   Shape right_strided_slice_end_;
105   Shape right_strided_slice_strides_;
106 
107   std::vector<int64_t> send_rank_ids_;
108   std::vector<int64_t> recv_rank_ids_;
109   Shapes send_shapes_;
110   Shapes recv_shapes_;
111 
112   GenerateGraph gen_g_ = GenerateGraph(attrs_);
113 
114   virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
115   virtual void InferNewPadList();
116   virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
117   virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
118 
119  private:
120   Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy);
121   Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
122 };
123 
124 class Conv2DBackpropInputInfo : public Conv2DInfo {
125  public:
Conv2DBackpropInputInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)126   Conv2DBackpropInputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
127                           const PrimitiveAttrs &attrs)
128       : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
129   ~Conv2DBackpropInputInfo() override = default;
130   void UpdateOutShape();
131   void ReplaceNodeInputOrAttrs() override;
132 
133  protected:
134   Status GetAttrs() override;
135   Status GetOutShape();
136   Status CheckStrategy(const StrategyPtr &strategy) override;
137   Status InferDevMatrixShape() override;
138   Status InferTensorMap() override;
139   Status InferMirrorOps() override;  // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
140 
141   Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override;
142   void InferNewPadList() override;
143   int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
144   int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;
145 
146  private:
147   Shape out_shape_;
148   Shape out_slice_shape_;
149 };
150 
151 class Conv2DTransposeInfo : public Conv2DBackpropInputInfo {
152  public:
Conv2DTransposeInfo(const std::string & name,const Shapes & inputs_shape,const Shapes & outputs_shape,const PrimitiveAttrs & attrs)153   Conv2DTransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
154                       const PrimitiveAttrs &attrs)
155       : Conv2DBackpropInputInfo(name, inputs_shape, outputs_shape, attrs) {}
156   ~Conv2DTransposeInfo() override = default;
157 };
158 
// Index constant with value 1.
// NOTE(review): presumably the position of the input-channel dimension within a shape; confirm at the use sites.
constexpr size_t IN_CHANNEL_INDEX = 1;
160 using Conv2DBackpropInputInfoPtr = std::shared_ptr<Conv2DBackpropInputInfo>;
161 }  // namespace parallel
162 }  // namespace mindspore
163 
164 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONV2D_INFO_H_
165