/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_

#include <vector>
#include <string>
#include <utility>
#include <set>
#include <map>
#include <memory>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

namespace mindspore {
namespace parallel {

bool IsDynamicShapeInput(const CNodePtr &node, const AnfNodePtr &input);
// The input value may be dynamic for these ops.
static const std::set<std::string> CANDIDATE_DYNAMIC_VALUE_OPS = {RESHAPE, STRIDED_SLICE, PAD_V3,
                                                                  TILE, FILLV2, UNIFORM_REAL};
// Split the tensor only for the first input.
static const std::set<std::string> SPLIT_TENSOR_ONLY_FOR_FIRST_INPUT_OPS = {PAD_V3};
// The input is a tuple or a list.
static const std::set<std::string> INPUT_IS_TUPLE_OR_LIST_OPS = {
  CONCAT, STACK, ADDN, INCRE_FLASH_ATTENTION, MESHGRID, FUSED_INFER_ATTENTION_SCORE, GROUPED_MATMUL, STACK_EXT};
// Ops that support the new ShapeBase shapes.
static const std::set<std::string> SUPPORT_NEW_SHAPEBASE_OPS = {VIRTUAL_DATA_SET, FUSED_INFER_ATTENTION_SCORE,
                                                                GROUPED_MATMUL};
// Op white list for AllReduce pull-down.
static const std::set<std::string> ALLREDUCE_PULL_DOWN_WHITE_LIST = {
  TUPLE_GETITEM_OP, RESHAPE, TRANSPOSE, MIRROR_OPERATOR, ADD, MUL, DIV, GATHERV2};

const int64_t TWO_INPUT_SIZE = 2;

constexpr char KAttrAsLossDivisor[] = "as_loss_divisor";
constexpr char KAttrDevMatrixShape[] = "dev_matrix_shape";
constexpr char KAttrInputsTensorMap[] = "inputs_tensor_map";
constexpr char KAttrOutputsTensorMap[] = "outputs_tensor_map";
constexpr int64_t DYNAMIC_DIM_VAL = -1;

extern size_t TOTAL_OPS;
extern std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
struct CommInfo {
  int64_t device_num = 1;
  int64_t global_rank = 0;
  std::string world_group;
  std::string communication_backend;
};
const std::set<std::string> COMMUNICATION_OPS = {
  ALL_REDUCE,         ALL_GATHER,      ALL_TO_ALL,         REDUCE_SCATTER,    BROADCAST,       NEIGHBOREXCHANGE,
  NEIGHBOREXCHANGEV2, SYNC_BATCH_NORM, COLLECTIVE_SCATTER, COLLECTIVE_GATHER, BATCHISENDIRECV, ALL_TO_ALLV};
// Common utility methods.
CommInfo GetCommInfo();
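// Illustrative usage (a sketch, not part of this header's contract): callers typically read the
// communication info once and then work with the device count and local rank, e.g.
//   CommInfo comm_info = GetCommInfo();
//   int64_t device_num = comm_info.device_num;
//   int64_t global_rank = comm_info.global_rank;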
ShapeVector ToFullShape(const ShapeVector &input_shape, size_t index);
void ExtendInputArgsAbstractShape(const AbstractBasePtr &args_abstract_item, size_t index);
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list);
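// Illustrative usage (a sketch; RESHAPE, MAKE_TUPLE and MAKE_LIST are the primitive-name string
// constants already used in this module):
//   if (IsSomePrimitive(cnode, RESHAPE)) { /* handle Reshape */ }
//   if (IsSomePrimitiveList(cnode, {MAKE_TUPLE, MAKE_LIST})) { /* handle sequence makers */ }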
bool IsParallelCareNode(const CNodePtr &cnode);
bool IsAutoParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
bool HasSupportedValueSequence(const CNodePtr &node);
// Extract shapes from an AnfNode.
std::vector<Shapes> ExtractShape(const CNodePtr &node);
std::vector<NewShapes> ExtractNewShape(const CNodePtr &node);
std::vector<Shapes> ExtractRealDivisor(const CNodePtr &node);
// Generate and initialize the parallel operator.
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
                                 const std::vector<Shapes> &shape_list);
OperatorInfoPtr CreateOperatorInfo(const CNodePtr &cnode);
OperatorInfoPtr CreateOperatorInfoForTupleShape(const CNodePtr &cnode);
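// Illustrative flow (a sketch of how these helpers are commonly combined by the parallel passes):
//   if (IsParallelCareNode(cnode)) {
//     OperatorInfoPtr op_info = CreateOperatorInfo(cnode);  // builds and initializes the OperatorInfo
//   }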
std::string GetPrimName(const CNodePtr &node);
std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info);
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
bool IsControlFlowNode(const AnfNodePtr &node);
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index,
                                                 CNodePtr *call_node = nullptr, bool ignore_get_item = true);
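// Illustrative usage (a sketch; the -1 passed here is assumed to mean "no TupleGetItem index"):
//   AnfNodePtr real_node = GetRealKernelNode(node, -1).first;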

std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesWithFilter(const AnfNodePtr &node,
                                                                 std::function<bool(const AnfNodePtr &)> filter);
AnfNodePtr GetInputNodeWithFilter(const AnfNodePtr &node,
                                  std::function<std::pair<bool, size_t>(const CNodePtr &)> filter);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(
  const AnfNodePtr &node, const FuncGraphManagerPtr &manager, const NodeUsersMap &node_users_map,
  const std::vector<int> &get_item_index, int64_t make_tuple_index,
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> *next_nodes);
AnfNodePtr NewMicroMirrorPrimByMicroMirror(const FuncGraphPtr &func_graph, const CNodePtr &micro_mirror,
                                           const AnfNodePtr &micro_mirror_new_input);
// For specific scenarios.
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
bool IsTraining(const FuncGraphManagerPtr &manager);
bool HasBackward(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type);
TypePtr FindChildCastWithFP32ToFP16(const std::pair<AnfNodePtr, int> &res, const NodeUsersMap &node_users_map);
void LabelGenMaskMicro(const FuncGraphPtr &root);
void AddNodeFusionInfo(const CNodePtr &node, const CNodePtr &comm_node, const std::string &backward_comm_name,
                       const std::string &param_name, int32_t fusion_id);
void AddNodeMirrorInfo(const CNodePtr &cnode, const std::string &param_name);
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes);
bool IsPynativeParallel();
bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph);
bool HasNestedMetaFg(const FuncGraphPtr &func_graph);
bool IsEmbedShardNode(const FuncGraphPtr &func_graph);
bool IsSplittableOperator(const std::string &op_name);
AnfNodePtr FindRealInputByFormalParameter(const CNodePtr &node, const AnfNodePtr &input,
                                          const std::vector<AnfNodePtr> &all_nodes);
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node, const std::vector<AnfNodePtr> &all_nodes);
OperatorInfoPtr GetDistributeOperator(const CNodePtr &node);
bool StrategyFound(const mindspore::HashMap<std::string, ValuePtr> &attrs);
bool AttrFound(const mindspore::HashMap<std::string, ValuePtr> &attrs, const std::string &target);
void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
std::string MirrorOpName();
// Extract the strategy from an attr value.
StrategyPtr ExtractStrategy(const ValuePtr &stra);
StrategyPtr ExtractNewStrategy(const ValuePtr &stra);
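// Illustrative usage (a sketch; STRATEGY is assumed here to be the strategy attr key constant):
//   if (StrategyFound(attrs)) {
//     StrategyPtr strategy = ExtractStrategy(attrs.at(STRATEGY));
//   }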
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
std::vector<std::pair<AnfNodePtr, int>> FuncGraphNodeUsers(const std::pair<AnfNodePtr, int> &node_pair);
Status ParallelInit(size_t rank_id = 0, const size_t devices = 0);
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth);
std::pair<std::shared_ptr<AnfNode>, int> BFSParallelCareNode(const AnfNodePtr &node_ptr,
                                                             const NodeUsersMap &node_users_map, const int index,
                                                             const std::vector<AnfNodePtr> &all_nodes);
void FindPreNodeCrossFuncGraph(CNodePtr *cnode, int64_t out_index);
bool CrossInterNode(CNodePtr *prev_cnode, ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim);
bool IsCarePrevCNode(const CNodePtr &prev_cnode, const PrimitivePtr &prev_prim);
void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter);
StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
StrategyPtr GenerateStandAloneStrategyForNewShapes(const NewShapes &inputs_shape);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsInsertVirtualOutput(const FuncGraphPtr &root);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair, const int &make_tuple_index);
Shape mirror_group_list(const TensorLayoutPtr &layout);
// Convert a number to a serial-number string.
std::string GetSerialNumberString(size_t number);
size_t GetDeviceCapacity();
bool IsIgnoreSplitTensor(const CNodePtr &node, int64_t index);
bool MergeConcatSlice(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> &all_nodes);
Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
                               const Shapes &outputs_shape,
                               std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
                               std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts);
inline bool IsMakeSequence(const AnfNodePtr &node) {
  return AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST);
}
inline bool IsValueSequence(const AnfNodePtr &node) {
  return IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node);
}
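// Illustrative usage (a sketch): distinguishing sequences built in the graph from constant ones:
//   if (IsMakeSequence(input)) { /* MakeTuple / MakeList cnode */ }
//   else if (IsValueSequence(input)) { /* constant ValueTuple / ValueList */ }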
bool IsCellReuseForwardGraph(const FuncGraphPtr &graph);
FuncGraphPtr GetCellReuseBackwardGraph(const FuncGraphPtr &forward_graph);
bool IsCommunicationOp(const PrimitivePtr &prim);
void ConvertInterleaveAllGatherToConcat(const FuncGraphPtr &func_graph, const CNodePtr &virtual_converter_end,
                                        const std::vector<std::vector<std::vector<int64_t>>> &ag_group_ranks_vectors);
void SplitNotParallelCareOpsInterleaved(const FuncGraphPtr &root);
void EraseVirtualConverter(const FuncGraphPtr &root);

inline void SetReserved(const FuncGraphPtr &root) {
  // Keep all func graphs for parallel before saving the result.
  root->set_reserved(true);
  for (auto &fg : root->func_graphs_used_total()) {
    MS_EXCEPTION_IF_NULL(fg);
    fg->set_reserved(true);
  }
}

abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list);
} // namespace parallel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STEP_PARALLEL_UTILS_H_