• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
18 #define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
19 
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include <list>
25 #include <map>
26 
27 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
28 #include "frontend/parallel/ops_info/operator_info.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 static std::map<std::string, Dimensions> param_strategy_;
33 class RecStrategyPropagator {
34  public:
35   typedef std::list<size_t> prop_list_t;
36 
37  private:
38   std::shared_ptr<Graph> graph_;
39   const std::vector<std::shared_ptr<OperatorInfo>> &ops_;
40   std::shared_ptr<std::vector<std::vector<size_t>>> eli_list_;
41   const std::vector<std::vector<std::string>> &input_tensor_names_;
42   std::shared_ptr<std::vector<size_t>> index_list_;
43   bool is_training_;
44   std::vector<std::vector<size_t>> shared_tensors_ops_;
45   FuncGraphPtr root_;
46 
47   prop_list_t forward_;
48   prop_list_t backward_;
49   std::shared_ptr<std::vector<size_t>> no_stra_op_list_;
50   std::vector<size_t> source_ops_;
51 
52   void FixInvalidStra();
53   void CheckConnectedComponents();
54 
55   void AjustToNoTraining();
56 
57   void ApplyStrategy(size_t i_op, const Strategies &strategy);
58 
59   size_t GenerateEliminatedOperatorStrategyForward(size_t min_devices = 1);
60   size_t GenerateEliminatedOperatorStrategyBackward(size_t min_devices = 1);
61   size_t GenerateRemainingOperatorStrategy();
62   size_t ModifyParamSharingOpsStrategy();
63   size_t AssignStandaloneAndBatchParallelOpStrategy();
64 
65   std::map<std::string, std::vector<std::pair<size_t, size_t>>> GetParamUsers();
66   void SetParamStrategy();
67   size_t ApplyParamStrategy();
68 
69  public:
70   RecStrategyPropagator(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
71                         const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
72                         const std::vector<std::vector<std::string>> &input_tensor_names,
73                         const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
74                         const std::vector<std::vector<size_t>> &shared_tensors_ops, const FuncGraphPtr &root);
75 
76   size_t GetMaxDimNum(size_t i_op);
77   Dimensions GetDefaultStrategy(size_t i_op);
78 
79   size_t CopyMainOperatorsStrategy();
80   size_t PropagateFromInputs();
81   size_t PropagateFromOutputs();
82 
83   void GenerateNoStraList();
84   void ExtraShardMatmulOnBatchDim();
85 
86   void GenerateStrategyV1();
87   void GenerateStrategyV3();
88 };
89 
// Strategy of operator `incoming_op_index`'s output as seen from input
// `i_op` — used to propagate a producer's sharding onto its consumer.
Dimensions GetInputStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                            const std::shared_ptr<std::vector<size_t>> &index_list, size_t i_op,
                            size_t incoming_op_index);

// Entry point of the rec auto-parallel strategy generation: constructs a
// RecStrategyPropagator over the given graph/ops and runs it (see .cc for
// which pipeline — V1 or V3 — is selected).
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                      const std::vector<std::vector<std::string>> &input_tensor_names,
                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
                      const std::vector<std::vector<size_t>> &shared_tensors_ops, const FuncGraphPtr &root);

// --- Per-operator strategy builders -------------------------------------
// Prepare*() functions derive a Strategies/Dimensions object for one operator
// kind from the rec-graph node and/or an already-chosen basic strategy.
Dimensions PrepareMatMulStrategy(Graph::NodeType *node, bool transpose_a, bool transpose_b, size_t iter_op_inputs);
Strategies PrepareMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op);
Dimensions PrepareBatchMatMulStrategy(Graph::NodeType *node, const bool transpose_a, const bool transpose_b,
                                      const size_t iter_op_inputs, const size_t dim_num);
Strategies PrepareBatchMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op);
Strategies PreparePropagateBatchMatMul(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra);
Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &strategy);
Strategies PrepareStridedSlice(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix);
Strategies PrepareSoftMax(const std::shared_ptr<OperatorInfo> &op, const Dimensions &basic_stra);
Strategies PrepareLayerNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra);
Strategies PrepareOneHot(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy);
Strategies PrepareGather(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix);
Dimensions PrepareGatherV2OutputStrategy(const std::shared_ptr<OperatorInfo> &op);
Strategies PrepareL2Normalize(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy);
Strategies PrepareAxisRelatedStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                      const size_t iter_ops);

// --- Whole-graph strategy modes ----------------------------------------
Strategies MakeRecSearchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                 const size_t iter_ops);
Strategies MakeDataParallelStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                    const size_t iter_ops);
Strategies MakeFullBatchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                 const size_t iter_ops);
void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
Strategies PrepareStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_ops, const bool dyn_shape_tmp_fix);
bool HasStrategy(std::shared_ptr<OperatorInfo> op);

// --- Propagation helpers ------------------------------------------------
// Locate the producing (incoming) / consuming (outgoing) operator of op
// `iter_ops` via tensor-name matching in input_tensor_names.
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, size_t iter_ops);
std::pair<size_t, size_t> FindIndexOfOperatorOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const std::vector<std::vector<std::string>> &input_tensor_names,
                                                      size_t iter_ops);
Dimensions CopyIncomingOperatorOutputStrategy(Graph::NodeType *node,
                                              const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                              const size_t iter_ops, const size_t incoming_op_index);
Dimensions PrepareReshapeOutputStrategy(const std::shared_ptr<OperatorInfo> &op);
Dimensions PrepareTransposeOutputStrategy(const std::shared_ptr<OperatorInfo> &op);
Dimensions PrepareExpandDimsOutputStrategy(const std::shared_ptr<OperatorInfo> &op);
Dimensions PrepareIncomingArithmeticOpeartorInputStrategy(const std::shared_ptr<OperatorInfo> &op);
Dimensions PrepareIncomingOperatorInputStrategy(const std::shared_ptr<OperatorInfo> &op);
Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops);
// ModifyStrategyIf*Incoming/Outgoing: adapt a propagated strategy when the
// neighbor op changes tensor rank (Squeeze/Reduce/Arg/Flatten).
Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const size_t incoming_op_index, Dimensions strategy);
Dimensions ModifyStrategyIfReduceIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy);
Dimensions GetDimListFromAttrs(const std::shared_ptr<OperatorInfo> &op);
Dimensions ModifyStrategyIfArgIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy);
Dimensions ModifyStrategyIfFlattenIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy);
Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                             const size_t iter_ops, const size_t incoming_op_index);
Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                          Dimensions basic_stra, bool dyn_shape_tmp_fix);
// Validity checks on a candidate strategy against the op's input shapes.
Strategies CheckBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy);
Dimensions ApplyBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy,
                          bool broadcast_first_tensor);
Strategies CheckDivisible(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy);
Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                           Dimensions strategy);
Dimensions PrepareTransposeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
                                         size_t outgoing_op_index, size_t iter_op_inputs);
Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops,
                                             size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix);
158 }  // namespace parallel
159 }  // namespace mindspore
160 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
161