• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17 #ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
18 #define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
19 
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
26 #include "frontend/parallel/ops_info/operator_info.h"
27 
28 namespace mindspore {
29 namespace parallel {
30 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
31                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
32                       const std::vector<std::vector<std::string>> &input_tensor_names,
33                       const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training);
34 Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
35                         const size_t iter_graph, const size_t iter_ops);
36 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
37 Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
38                               Dimensions basic_stra);
39 Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
40                         const size_t iter_graph, const size_t iter_ops);
41 Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
42                                      const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
43                                      const size_t iter_ops);
44 Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
45 Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
46 Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
47                                           const size_t incoming_op_index);
48 Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
49                              Dimensions s);
50 Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
51                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
52                                 const size_t iter_ops);
53 Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
54 Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
55                           size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor);
56 Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
57 Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
58                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
59                                    const size_t iter_ops);
60 Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
61                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
62                                 const size_t iter_ops);
63 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
64 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
65                           const size_t iter_graph, const size_t iter_ops);
66 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
67                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
68                                          const std::shared_ptr<std::vector<size_t>> &index_list);
69 size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
70                                    const size_t iter_ops);
71 Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
72                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
73                                               const size_t iter_ops, const size_t iter_graph,
74                                               const size_t incoming_op_index);
75 Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
76                                                 const size_t incoming_op_index);
77 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops);
78 Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
79                                            const size_t incoming_op_index, Dimensions s);
80 bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
81 Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
82 Dimensions ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
83                                           const size_t incoming_op_index, Dimensions s);
84 Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
85 Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
86                                        const size_t incoming_op_index, Dimensions s);
87 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
88                                              const size_t incoming_op_index);
89 Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
90                                          Dimensions basic_stra);
91 void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
92                                                const std::vector<std::shared_ptr<OperatorInfo>> &ops,
93                                                const std::vector<std::vector<std::string>> &input_tensor_names,
94                                                const std::shared_ptr<std::vector<size_t>> &index_list,
95                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
96 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
97                                            Dimensions s);
98 Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
99                                              const std::vector<std::vector<std::string>> &input_tensor_names,
100                                              const size_t iter_ops);
101 void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
102                                                 const std::vector<std::vector<std::string>> &input_tensor_names,
103                                                 const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
104 void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
105                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
106                                        const std::vector<std::vector<std::string>> &input_tensor_names,
107                                        const std::shared_ptr<std::vector<size_t>> &index_list,
108                                        const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
109 }  // namespace parallel
110 }  // namespace mindspore
111 #endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
112