1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ 18 #define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ 19 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 #include <list> 25 #include <map> 26 27 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 30 namespace mindspore { 31 namespace parallel { 32 static std::map<std::string, Dimensions> param_strategy_; 33 class RecStrategyPropagator { 34 public: 35 typedef std::list<size_t> prop_list_t; 36 37 private: 38 std::shared_ptr<Graph> graph_; 39 const std::vector<std::shared_ptr<OperatorInfo>> &ops_; 40 std::shared_ptr<std::vector<std::vector<size_t>>> eli_list_; 41 const std::vector<std::vector<std::string>> &input_tensor_names_; 42 std::shared_ptr<std::vector<size_t>> index_list_; 43 bool is_training_; 44 std::vector<std::vector<size_t>> shared_tensors_ops_; 45 FuncGraphPtr root_; 46 47 prop_list_t forward_; 48 prop_list_t backward_; 49 std::shared_ptr<std::vector<size_t>> no_stra_op_list_; 50 std::vector<size_t> source_ops_; 51 52 void FixInvalidStra(); 53 void CheckConnectedComponents(); 54 55 void AjustToNoTraining(); 56 57 void ApplyStrategy(size_t i_op, const Strategies &strategy); 58 59 size_t GenerateEliminatedOperatorStrategyForward(size_t min_devices = 1); 60 size_t GenerateEliminatedOperatorStrategyBackward(size_t min_devices = 1); 61 size_t GenerateRemainingOperatorStrategy(); 62 size_t ModifyParamSharingOpsStrategy(); 63 size_t AssignStandaloneAndBatchParallelOpStrategy(); 64 65 std::map<std::string, std::vector<std::pair<size_t, size_t>>> GetParamUsers(); 66 void SetParamStrategy(); 67 size_t ApplyParamStrategy(); 68 69 public: 70 RecStrategyPropagator(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 71 const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, 72 const std::vector<std::vector<std::string>> &input_tensor_names, 73 const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training, 74 const std::vector<std::vector<size_t>> &shared_tensors_ops, const FuncGraphPtr &root); 75 76 size_t GetMaxDimNum(size_t i_op); 77 Dimensions GetDefaultStrategy(size_t i_op); 78 79 size_t CopyMainOperatorsStrategy(); 80 size_t PropagateFromInputs(); 81 size_t PropagateFromOutputs(); 82 83 void GenerateNoStraList(); 84 void ExtraShardMatmulOnBatchDim(); 85 86 void GenerateStrategyV1(); 87 void GenerateStrategyV3(); 88 }; 89 90 Dimensions GetInputStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 91 const std::shared_ptr<std::vector<size_t>> &index_list, size_t i_op, 92 size_t incoming_op_index); 93 94 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 95 const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, 96 const std::vector<std::vector<std::string>> &input_tensor_names, 97 const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training, 98 const std::vector<std::vector<size_t>> &shared_tensors_ops, const FuncGraphPtr &root); 99 Dimensions PrepareMatMulStrategy(Graph::NodeType *node, bool transpose_a, bool transpose_b, size_t iter_op_inputs); 100 Strategies PrepareMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op); 101 Dimensions PrepareBatchMatMulStrategy(Graph::NodeType *node, const bool transpose_a, const bool transpose_b, 102 const size_t iter_op_inputs, const size_t dim_num); 103 Strategies PrepareBatchMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op); 104 Strategies PreparePropagateBatchMatMul(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra); 105 Strategies PrepareBiasAdd(const std::shared_ptr<Dimensions> &strategy); 106 Strategies PrepareStridedSlice(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix); 107 Strategies PrepareSoftMax(const std::shared_ptr<OperatorInfo> &op, const Dimensions &basic_stra); 108 Strategies PrepareLayerNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra); 109 Strategies PrepareOneHot(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy); 110 Strategies PrepareGather(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix); 111 Dimensions PrepareGatherV2OutputStrategy(const std::shared_ptr<OperatorInfo> &op); 112 Strategies PrepareL2Normalize(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy); 113 Strategies PrepareAxisRelatedStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 114 const size_t iter_ops); 115 Strategies MakeRecSearchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 116 const size_t iter_ops); 117 Strategies MakeDataParallelStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 118 const size_t iter_ops); 119 Strategies MakeFullBatchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 120 const size_t iter_ops); 121 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op); 122 Strategies PrepareStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops, 123 const size_t iter_ops, const bool dyn_shape_tmp_fix); 124 bool HasStrategy(std::shared_ptr<OperatorInfo> op); 125 size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, size_t iter_ops); 126 std::pair<size_t, size_t> FindIndexOfOperatorOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, 127 const std::vector<std::vector<std::string>> &input_tensor_names, 128 size_t iter_ops); 129 Dimensions CopyIncomingOperatorOutputStrategy(Graph::NodeType *node, 130 const std::vector<std::shared_ptr<OperatorInfo>> &ops, 131 const size_t iter_ops, const size_t incoming_op_index); 132 Dimensions PrepareReshapeOutputStrategy(const std::shared_ptr<OperatorInfo> &op); 133 Dimensions PrepareTransposeOutputStrategy(const std::shared_ptr<OperatorInfo> &op); 134 Dimensions PrepareExpandDimsOutputStrategy(const std::shared_ptr<OperatorInfo> &op); 135 Dimensions PrepareIncomingArithmeticOpeartorInputStrategy(const std::shared_ptr<OperatorInfo> &op); 136 Dimensions PrepareIncomingOperatorInputStrategy(const std::shared_ptr<OperatorInfo> &op); 137 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops); 138 Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, 139 const size_t incoming_op_index, Dimensions strategy); 140 Dimensions ModifyStrategyIfReduceIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy); 141 Dimensions GetDimListFromAttrs(const std::shared_ptr<OperatorInfo> &op); 142 Dimensions ModifyStrategyIfArgIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy); 143 Dimensions ModifyStrategyIfFlattenIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy); 144 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, 145 const size_t iter_ops, const size_t incoming_op_index); 146 Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, 147 Dimensions basic_stra, bool dyn_shape_tmp_fix); 148 Strategies CheckBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy); 149 Dimensions ApplyBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy, 150 bool broadcast_first_tensor); 151 Strategies CheckDivisible(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy); 152 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, 153 Dimensions strategy); 154 Dimensions PrepareTransposeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops, 155 size_t outgoing_op_index, size_t iter_op_inputs); 156 Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops, 157 size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix); 158 } // namespace parallel 159 } // namespace mindspore 160 #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ 161