/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_

#include <vector>
#include <string>
#include <utility>
#include <set>
#include <map>
#include <memory>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

namespace mindspore {
namespace parallel {

bool IsDynamicShapeInput(const CNodePtr &node, const AnfNodePtr &input);
// ops whose input values may be dynamic
static const std::set<std::string> CANDIDATE_DYNAMIC_VALUE_OPS = {RESHAPE, STRIDED_SLICE, PAD_V3,
                                                                  TILE,    FILLV2,        UNIFORM_REAL};
// ops that split the tensor only for the first input
static const std::set<std::string> SPLIT_TENSOR_ONLY_FOR_FIRST_INPUT_OPS = {PAD_V3};
// ops whose input is a tuple or list
static const std::set<std::string> INPUT_IS_TUPLE_OR_LIST_OPS = {
  CONCAT, STACK, ADDN, INCRE_FLASH_ATTENTION, MESHGRID, FUSED_INFER_ATTENTION_SCORE, GROUPED_MATMUL, STACK_EXT};
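// For example, Concat consumes a MakeTuple of tensors rather than plain tensor
// inputs, so passes over these ops have to look through the tuple construction
// (see IsMakeSequence below) to reach the real operands.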
// ops that support the new ShapeBase-style shapes
static const std::set<std::string> SUPPORT_NEW_SHAPEBASE_OPS = {VIRTUAL_DATA_SET, FUSED_INFER_ATTENTION_SCORE,
                                                                GROUPED_MATMUL};
// whitelist of ops for AllReduce pull-down
static const std::set<std::string> ALLREDUCE_PULL_DOWN_WHITE_LIST = {
  TUPLE_GETITEM_OP, RESHAPE, TRANSPOSE, MIRROR_OPERATOR, ADD, MUL, DIV, GATHERV2};

const int64_t TWO_INPUT_SIZE = 2;

constexpr char KAttrAsLossDivisor[] = "as_loss_divisor";
constexpr char KAttrDevMatrixShape[] = "dev_matrix_shape";
constexpr char KAttrInputsTensorMap[] = "inputs_tensor_map";
constexpr char KAttrOutputsTensorMap[] = "outputs_tensor_map";
constexpr int64_t DYNAMIC_DIM_VAL = -1;

extern size_t TOTAL_OPS;
extern std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
struct CommInfo {
  int64_t device_num = 1;
  int64_t global_rank = 0;
  std::string world_group;
  std::string communication_backend;
};
const std::set<std::string> COMMUNICATION_OPS = {
  ALL_REDUCE,         ALL_GATHER,      ALL_TO_ALL,         REDUCE_SCATTER,    BROADCAST,       NEIGHBOREXCHANGE,
  NEIGHBOREXCHANGEV2, SYNC_BATCH_NORM, COLLECTIVE_SCATTER, COLLECTIVE_GATHER, BATCHISENDIRECV, ALL_TO_ALLV};
// common methods
CommInfo GetCommInfo();
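// A minimal usage sketch (assumes ParallelInit() has already set up the device
// manager; HCCL_BACKEND comes from device_manager.h):
//   CommInfo comm_info = GetCommInfo();
//   if (comm_info.communication_backend == HCCL_BACKEND) {
//     MS_LOG(INFO) << comm_info.device_num << " ranks in " << comm_info.world_group;
//   }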
ShapeVector ToFullShape(const ShapeVector &input_shape, size_t index);
void ExtendInputArgsAbstractShape(const AbstractBasePtr &args_abstract_item, size_t index);
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list);
bool IsParallelCareNode(const CNodePtr &cnode);
bool IsAutoParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
bool HasSupportedValueSequence(const CNodePtr &node);
// Extract shapes from an AnfNode
std::vector<Shapes> ExtractShape(const CNodePtr &node);
std::vector<NewShapes> ExtractNewShape(const CNodePtr &node);
std::vector<Shapes> ExtractRealDivisor(const CNodePtr &node);
// Generate and initialize the parallel operator info
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
                                 const std::vector<Shapes> &shape_list);
OperatorInfoPtr CreateOperatorInfo(const CNodePtr &cnode);
OperatorInfoPtr CreateOperatorInfoForTupleShape(const CNodePtr &cnode);
std::string GetPrimName(const CNodePtr &node);
std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info);
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
bool IsControlFlowNode(const AnfNodePtr &node);
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
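// GetRealKernelNode presumably resolves `node` to the node that actually
// produces the value, following TupleGetItem chains (element `get_item_index`)
// and crossing func-graph calls; the crossed call, if any, is returned through
// `call_node`, and the paired int64_t is the output index on the resolved node.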
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index,
                                                 CNodePtr *call_node = nullptr, bool ignore_get_item = true);

std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesWithFilter(const AnfNodePtr &node,
                                                                 std::function<bool(const AnfNodePtr &)> filter);
AnfNodePtr GetInputNodeWithFilter(const AnfNodePtr &node,
                                  std::function<std::pair<bool, size_t>(const CNodePtr &)> filter);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(
  const AnfNodePtr &node, const FuncGraphManagerPtr &manager, const NodeUsersMap &node_users_map,
  const std::vector<int> &get_item_index, int64_t make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes);
AnfNodePtr NewMicroMirrorPrimByMicroMirror(const FuncGraphPtr &func_graph, const CNodePtr &micro_mirror,
                                           const AnfNodePtr &micro_mirror_new_input);
// for specific scenarios
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
bool IsTraining(const FuncGraphManagerPtr &manager);
bool HasBackward(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type);
TypePtr FindChildCastWithFP32ToFP16(const std::pair<AnfNodePtr, int> &res, const NodeUsersMap &node_users_map);
void LabelGenMaskMicro(const FuncGraphPtr &root);
void AddNodeFusionInfo(const CNodePtr &node, const CNodePtr &comm_node, const std::string &backward_comm_name,
                       const std::string &param_name, int32_t fusion_id);
void AddNodeMirrorInfo(const CNodePtr &cnode, const std::string &param_name);
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes);
bool IsPynativeParallel();
bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph);
bool HasNestedMetaFg(const FuncGraphPtr &func_graph);
bool IsEmbedShardNode(const FuncGraphPtr &func_graph);
bool IsSplittableOperator(const std::string &op_name);
AnfNodePtr FindRealInputByFormalParameter(const CNodePtr &node, const AnfNodePtr &input,
                                          const std::vector<AnfNodePtr> &all_nodes);
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node, const std::vector<AnfNodePtr> &all_nodes);
OperatorInfoPtr GetDistributeOperator(const CNodePtr &node);
bool StrategyFound(const mindspore::HashMap<std::string, ValuePtr> &attrs);
bool AttrFound(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::string &target);
void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
std::string MirrorOpName();
// Extract the strategy from attributes
StrategyPtr ExtractStrategy(const ValuePtr &stra);
StrategyPtr ExtractNewStrategy(const ValuePtr &stra);
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
std::vector<std::pair<AnfNodePtr, int>> FuncGraphNodeUsers(const std::pair<AnfNodePtr, int> &node_pair);
Status ParallelInit(size_t rank_id = 0, const size_t devices = 0);
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth);
std::pair<std::shared_ptr<AnfNode>, int> BFSParallelCareNode(const AnfNodePtr &node_ptr,
                                                             const NodeUsersMap &node_users_map, const int index,
                                                             const std::vector<AnfNodePtr> &all_nodes);
void FindPreNodeCrossFuncGraph(CNodePtr *cnode, int64_t out_index);
bool CrossInterNode(CNodePtr *prev_cnode, ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim);
bool IsCarePrevCNode(const CNodePtr &prev_cnode, const PrimitivePtr &prev_prim);
void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter);
StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
StrategyPtr GenerateStandAloneStrategyForNewShapes(const NewShapes &inputs_shape);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsInsertVirtualOutput(const FuncGraphPtr &root);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair, const int &make_tuple_index);
Shape mirror_group_list(const TensorLayoutPtr &layout);
// Convert a number to its serial-number string
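// (presumably e.g. 1 -> "1st", 2 -> "2nd"; see the .cc for the exact format)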
std::string GetSerialNumberString(size_t number);
size_t GetDeviceCapacity();
bool IsIgnoreSplitTensor(const CNodePtr &node, int64_t index);
bool MergeConcatSlice(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> &all_nodes);
Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
                               const Shapes &outputs_shape,
                               std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
                               std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts);
inline bool IsMakeSequence(const AnfNodePtr &node) {
  return AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST);
}
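// Example: a pass that needs to treat tuple and list construction uniformly
// can write
//   if (IsMakeSequence(node)) { /* visit every element input of the CNode */ }
// instead of checking MAKE_TUPLE and MAKE_LIST separately.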
inline bool IsValueSequence(const AnfNodePtr &node) {
  return IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node);
}
bool IsCellReuseForwardGraph(const FuncGraphPtr &graph);
FuncGraphPtr GetCellReuseBackwardGraph(const FuncGraphPtr &forward_graph);
bool IsCommunicationOp(const PrimitivePtr &prim);
void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end,
                                        const std::vector<std::vector<std::vector<int64_t>>> &ag_group_ranks_vectors);
void SplitNotParallelCareOpsInterleaved(const FuncGraphPtr &root);
void EraseVirtualConverter(const FuncGraphPtr &root);

inline void SetReserved(const FuncGraphPtr &root) {
  // Keep all func graphs alive for parallel before saving the result.
  root->set_reserved(true);
  for (auto &fg : root->func_graphs_used_total()) {
    MS_EXCEPTION_IF_NULL(fg);
    fg->set_reserved(true);
  }
}
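// Marking the root and every graph it uses as reserved presumably prevents
// them from being released before the parallel result is saved.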

abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_