1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "frontend/parallel/pass/matmul_add_comm_reduction.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/common/utils/utils.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/other_ops.h"
29
namespace mindspore {
namespace parallel {
namespace {
// Minimum number of qualified comm nodes (one per Add input branch) required
// before the reduction can merge them into a single comm op below the Add.
constexpr size_t kCommReductionValidCommOpsNum = 2;
// Attribute markers set by the frontend to delimit the matmul-add-comm pattern.
constexpr auto MATMUL_ADD_COMM_BEGIN = "matmul_add_comm_begin";
constexpr auto MATMUL_ADD_COMM_END = "matmul_add_comm_end";
constexpr auto MATMUL_ADD_COMM_MUL = "matmul_add_comm_mul";
// Attribute stamped on the newly created (pulled-down) comm node.
constexpr const char MATMUL_ADD_COMM_REDUCTION[] = "matmul_add_comm_reduction";
38
IsSubRankList(const RankList & child_list,const RankList & parent_list)39 bool IsSubRankList(const RankList &child_list, const RankList &parent_list) {
40 for (auto &child : child_list) {
41 if (std::find(parent_list.begin(), parent_list.end(), child) == parent_list.end()) {
42 return false;
43 }
44 }
45 return true;
46 }
47
IsPrimitiveAttrValid(const PrimitivePtr & prim,const std::string & attr_name)48 bool IsPrimitiveAttrValid(const PrimitivePtr &prim, const std::string &attr_name) {
49 MS_EXCEPTION_IF_NULL(prim);
50 return !prim->HasAttr(attr_name) || !GetValue<bool>(prim->GetAttr(attr_name));
51 }
52
IsAddNodeValid(const AnfNodePtr & add_node,const AnfNodePtr & comm_node)53 bool IsAddNodeValid(const AnfNodePtr &add_node, const AnfNodePtr &comm_node) {
54 OperatorInfoPtr add_distribute_operator = add_node->user_data<OperatorInfo>();
55 if (add_distribute_operator == nullptr) {
56 return false;
57 }
58 TensorInfo node_add_tensor_in = add_distribute_operator->inputs_tensor_info()[LongToSize(1)];
59 TensorLayout node_add_tensor_layout = node_add_tensor_in.tensor_layout();
60 const auto node_add_rank_list = node_add_tensor_layout.InferRepeatedGroup();
61
62 auto comm_prim = GetCNodePrimitive(comm_node);
63 if (!comm_prim->HasAttr(GROUP)) {
64 return false;
65 }
66 auto comm_group = GetValue<std::string>(comm_prim->GetAttr(GROUP));
67 MS_EXCEPTION_IF_NULL(g_device_manager);
68 auto comm_rank_list = g_device_manager->FindRankListByHashName(comm_group);
69 return IsSubRankList(comm_rank_list, node_add_rank_list);
70 }
71
IsPrimitiveLinear(const AnfNodePtr & anode)72 bool IsPrimitiveLinear(const AnfNodePtr &anode) {
73 MS_EXCEPTION_IF_NULL(anode);
74 if (IsPrimitiveCNode(anode, prim::kPrimReduceAll) || IsPrimitiveCNode(anode, prim::kPrimReduceAny) ||
75 IsPrimitiveCNode(anode, prim::kPrimReduceMean) || IsPrimitiveCNode(anode, prim::kPrimReduceMax) ||
76 IsPrimitiveCNode(anode, prim::kPrimReduceMin) || IsPrimitiveCNode(anode, prim::kPrimReduceProd) ||
77 IsPrimitiveCNode(anode, prim::kPrimReduceSum) || IsPrimitiveCNode(anode, prim::kPrimSquareSumV1)) {
78 return false;
79 }
80 return true;
81 }
82
// Walks upward from anode through ops in ALLREDUCE_PULL_DOWN_WHITE_LIST (plus
// MatMul/BatchMatMul) looking for a comm node that is a candidate to be pulled
// below the Add. Returns the first node where the walk stops.
AnfNodePtr FindPullDownNode(const AnfNodePtr &anode) {
  auto pre_node = GetInputNodeWithFilter(anode, [&](const AnfNodePtr &cur_anode) {
    auto cur_cnode = cur_anode->cast<CNodePtr>();
    // cast<CNodePtr>() returns nullptr for non-CNode inputs (Parameter/ValueNode);
    // guard it before touching inputs() to avoid a null dereference.
    if (cur_cnode == nullptr) {
      return std::make_pair(false, LongToSize(0));
    }
    auto prim = GetCNodePrimitive(cur_cnode);
    if (prim == nullptr) {
      return std::make_pair(false, LongToSize(0));
    }
    auto cur_node_input_list = cur_cnode->inputs();
    for (size_t i = 1; i < cur_node_input_list.size(); ++i) {
      auto cur_input_node = cur_node_input_list[i];
      // find first non Tensor CNode
      if (IsValueNode<tensor::Tensor>(cur_input_node)) {
        continue;
      }
      auto input_prim = GetCNodePrimitive(cur_input_node);
      if (input_prim == nullptr) {
        return std::make_pair(false, i);
      }
      // cur prim must be in ALLREDUCE_PULL_DOWN_WHITE_LIST and input_prim is not marked or marked false
      bool filter = (ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim->name()) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end() ||
                     prim->name() == MATMUL || prim->name() == BATCH_MATMUL) &&
                    IsPrimitiveAttrValid(input_prim, MATMUL_ADD_COMM_BEGIN);
      return std::make_pair(filter, i);
    }
    return std::make_pair(false, LongToSize(1));
  });
  return pre_node;
}
111
FindAllValidAddNode(const FuncGraphPtr & graph,HashMap<AnfNodePtr,std::vector<AnfNodePtr>> * pull_down_node_map)112 void FindAllValidAddNode(const FuncGraphPtr &graph, HashMap<AnfNodePtr, std::vector<AnfNodePtr>> *pull_down_node_map) {
113 std::list<CNodePtr> graph_orders = graph->GetOrderedCnodes();
114 std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
115 for (const auto &node : origin_nodes_topological) {
116 // add node
117 auto prim = GetCNodePrimitive(node);
118 if (prim == nullptr || prim->name() != ADD || IsPrimitiveAttrValid(prim, MATMUL_ADD_COMM_END)) {
119 continue;
120 }
121 auto input_nodes = node->inputs();
122 for (size_t i = 1; i < input_nodes.size(); ++i) {
123 auto input_node = input_nodes[i];
124 if (!IsPrimitiveLinear(input_node)) {
125 continue;
126 }
127 auto comm_node = FindPullDownNode(input_node);
128 if (comm_node == nullptr) {
129 MS_LOG(INFO) << "For matmul add comm reduction, can not find valid comm node, node is "
130 << input_node->DebugString();
131 continue;
132 }
133 if ((!IsPrimitiveCNode(comm_node, prim::kPrimAllReduce) &&
134 !IsPrimitiveCNode(comm_node, prim::kPrimReduceScatter))) {
135 MS_LOG(INFO) << "For matmul comm reduction, comm node is not allreduce or reduce scatter, node is "
136 << comm_node->DebugString();
137 continue;
138 }
139
140 auto comm_cnode = comm_node->cast<CNodePtr>();
141 MS_EXCEPTION_IF_NULL(comm_node);
142 auto pre_prim = GetCNodePrimitive(comm_cnode->input(1));
143 if (pre_prim == nullptr || IsPrimitiveAttrValid(pre_prim, MATMUL_ADD_COMM_BEGIN)) {
144 MS_LOG(INFO) << "For matmul comm reduction, cannot find matmul/batch matmul node, "
145 << "skip cur node: " << input_node->DebugString();
146 continue;
147 }
148 (*pull_down_node_map)[node].push_back(comm_node);
149 MS_LOG(INFO) << "For matmul comm reduction, find one side with matmul-allreduce structure, add node is: "
150 << node->DebugString() << " comm node is: " << comm_node->DebugString();
151 }
152 }
153 }
154
// Starting from add_node_input, walks upward through whitelisted ops (including
// MatMul/BatchMatMul) until an Add node, the comm node itself, or an unknown op
// is reached. Returns the node where the walk stopped.
AnfNodePtr FindBiasAdd(const AnfNodePtr &comm_node, const AnfNodePtr &add_node_input) {
  MS_EXCEPTION_IF_NULL(comm_node);
  auto add_node = GetInputNodeWithFilter(add_node_input, [&](const AnfNodePtr &cur_node) {
    auto cur_prim = GetCNodePrimitive(cur_node);
    if (cur_prim == nullptr) {
      return std::make_pair(false, 0);
    }
    // Keep walking while the op lies in ALLREDUCE_PULL_DOWN_WHITE_LIST, is not
    // an Add node, and is not the comm node we started from.
    const auto &prim_name = cur_prim->name();
    bool in_white_list = ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim_name) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end() ||
                         prim_name == MATMUL || prim_name == BATCH_MATMUL;
    bool keep_walking = in_white_list && prim_name != ADD && cur_node != comm_node;
    return std::make_pair(keep_walking, 1);
  });
  return add_node;
}
170
HandleNodeBiasAdd(const AnfNodePtr & comm_node,const AnfNodePtr & add_node_input)171 void HandleNodeBiasAdd(const AnfNodePtr &comm_node, const AnfNodePtr &add_node_input) {
172 MS_EXCEPTION_IF_NULL(comm_node);
173 MS_EXCEPTION_IF_NULL(add_node_input);
174 auto comm_prim = GetCNodePrimitive(comm_node);
175 MS_EXCEPTION_IF_NULL(comm_prim);
176 if (!comm_prim->HasAttr(GROUP)) {
177 MS_LOG(INFO) << "For matmul comm reduction, cur prim has not attr " << GROUP
178 << ", skip it, node is: " << comm_node->DebugString();
179 return;
180 }
181 auto comm_group = GetValue<std::string>(comm_prim->GetAttr(GROUP));
182 MS_EXCEPTION_IF_NULL(g_device_manager);
183 auto comm_rank_list = g_device_manager->FindRankListByHashName(comm_group);
184 double rank_size = 1.0 / comm_rank_list.size();
185
186 auto add_node = FindBiasAdd(comm_node, add_node_input);
187 if (add_node == nullptr || !IsPrimitiveCNode(add_node, prim::kPrimAdd)) {
188 MS_LOG(INFO) << "For matmul comm reduction, cannot find bias add node, find node is: " << add_node->DebugString()
189 << " start node is " << add_node_input->DebugString();
190 return;
191 }
192 if (!IsAddNodeValid(add_node, comm_node)) {
193 MS_LOG(INFO) << "For matmul comm reduction, strategy of add node mismatched, skip it, add node is: "
194 << add_node->DebugString();
195 return;
196 }
197 auto add_cnode = add_node->cast<CNodePtr>();
198 MS_EXCEPTION_IF_NULL(add_cnode);
199 // find load node for bias parameter
200 auto bias_side_start_node = add_cnode->input(2);
201 auto bias_node = GetInputNodeWithFilter(bias_side_start_node, [&](const AnfNodePtr &anode) {
202 auto prim = GetCNodePrimitive(anode);
203 if (prim == nullptr) {
204 return std::make_pair(false, 0);
205 }
206 bool filter = ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim->name()) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end();
207 return std::make_pair(filter, 1);
208 });
209 if (bias_node == nullptr || !IsPrimitiveCNode(bias_node, prim::kPrimLoad)) {
210 MS_LOG(INFO) << "For comm reduction, cannot find load op for bias parameter along current add node, please "
211 "check whether it exists, cur add node is: "
212 << add_node->DebugString();
213 return;
214 }
215 // insert mul node
216 auto bias_node_abstract = bias_node->abstract();
217 MS_EXCEPTION_IF_NULL(bias_node_abstract);
218 auto bias_dtype = bias_node_abstract->cast<abstract::AbstractTensorPtr>();
219 MS_EXCEPTION_IF_NULL(bias_dtype);
220 auto bias_dtype_ele = bias_dtype->element();
221 MS_EXCEPTION_IF_NULL(bias_dtype_ele);
222 mindspore::tensor::TensorPtr tensor_ptr =
223 std::make_shared<mindspore::tensor::Tensor>(rank_size, bias_dtype_ele->GetType());
224 auto const_node = NewValueNode(MakeValue(tensor_ptr));
225 const_node->set_abstract(const_node->value()->ToAbstract());
226
227 auto mul_prim = NewValueNode(prim::kPrimMul);
228 auto cur_prim = GetValueNode<PrimitivePtr>(mul_prim);
229 MS_EXCEPTION_IF_NULL(cur_prim);
230 (void)cur_prim->AddAttr(MATMUL_ADD_COMM_MUL, MakeValue(true));
231 AnfNodePtrList mul_node_inputs = {mul_prim, bias_node, const_node};
232 auto fg = comm_node->func_graph();
233 MS_EXCEPTION_IF_NULL(fg);
234 auto mul_node = fg->NewCNode(mul_node_inputs);
235 mul_node->set_abstract(bias_node->abstract()->Clone());
236
237 MS_EXCEPTION_IF_NULL(fg);
238 auto manager = fg->manager();
239 MS_EXCEPTION_IF_NULL(manager);
240 (void)manager->Replace(bias_node, mul_node);
241 MS_LOG(INFO) << "for comm reduction, insert new mul node after parameter node";
242 }
243
HandleNodePullUp(const AnfNodePtr & add_node,const std::vector<AnfNodePtr> & comm_node_list,HashMap<AnfNodePtr,AnfNodePtr> * comm_node_map)244 void HandleNodePullUp(const AnfNodePtr &add_node, const std::vector<AnfNodePtr> &comm_node_list,
245 HashMap<AnfNodePtr, AnfNodePtr> *comm_node_map) {
246 for (size_t index = 0; index < comm_node_list.size(); ++index) {
247 // Node pull down
248 // Node After AllReduce pull up
249 auto each_node = comm_node_list[index];
250 auto each_cnode = each_node->cast<CNodePtr>();
251 auto pre_node = each_cnode->input(1);
252 auto pre_prim = GetCNodePrimitive(pre_node);
253 if (pre_prim == nullptr || IsPrimitiveAttrValid(pre_prim, MATMUL_ADD_COMM_BEGIN)) {
254 MS_LOG(INFO) << "For comm reduction, its pre node does not marked or marked false, skip it.";
255 continue;
256 }
257 auto graph = each_node->func_graph();
258 MS_EXCEPTION_IF_NULL(graph);
259 auto manager = graph->manager();
260 MS_EXCEPTION_IF_NULL(manager);
261 auto add_cnode = add_node->cast<CNodePtr>();
262 HandleNodeBiasAdd(each_node, add_cnode->input(index + 1));
263 (void)manager->Replace(each_node, pre_node);
264 MS_LOG(INFO) << "For comm reduction, pull up node next to comm node, node is: " << pre_node->DebugString();
265 if ((*comm_node_map).find(add_node) == (*comm_node_map).end()) {
266 (*comm_node_map)[add_node] = each_node;
267 }
268 }
269 }
270
HandleNodePullDown(const AnfNodePtr & add_node,const AnfNodePtr & comm_node)271 void HandleNodePullDown(const AnfNodePtr &add_node, const AnfNodePtr &comm_node) {
272 auto comm_cnode = comm_node->cast<CNodePtr>();
273 MS_EXCEPTION_IF_NULL(comm_cnode);
274 AnfNodePtrList new_comm_node_inputs = {comm_cnode->input(0), add_node};
275 auto graph = add_node->func_graph();
276 MS_EXCEPTION_IF_NULL(graph);
277 auto new_comm_node = graph->NewCNode(new_comm_node_inputs);
278 new_comm_node->set_abstract(comm_node->abstract());
279 auto prim = GetCNodePrimitive(new_comm_node);
280 (void)prim->AddAttr(MATMUL_ADD_COMM_REDUCTION, MakeValue(true));
281
282 auto manager = graph->manager();
283 MS_EXCEPTION_IF_NULL(manager);
284 (void)manager->Replace(add_node, new_comm_node);
285 MS_LOG(INFO) << "For comm reduction, pull down comm node, node is: " << new_comm_node->DebugString();
286 }
287
HandleAddNode(const HashMap<AnfNodePtr,std::vector<AnfNodePtr>> & pull_down_node_map)288 void HandleAddNode(const HashMap<AnfNodePtr, std::vector<AnfNodePtr>> &pull_down_node_map) {
289 HashMap<AnfNodePtr, AnfNodePtr> comm_node_map;
290 for (auto &each_pull_down_node : pull_down_node_map) {
291 if (each_pull_down_node.second.size() < kCommReductionValidCommOpsNum) {
292 MS_LOG(INFO) << "For comm reduction, cur node cannot find match structure, skip it. current node is "
293 << each_pull_down_node.first->DebugString();
294 continue;
295 }
296 // Handle node pull up
297 HandleNodePullUp(each_pull_down_node.first, each_pull_down_node.second, &comm_node_map);
298 // Handle node pull down
299 HandleNodePullDown(each_pull_down_node.first, comm_node_map[each_pull_down_node.first]);
300 }
301 }
302
303 } // namespace
304
305 // For Structure as following:
306 // MatMul/BatchMatMul -> AllReduce -> ... -> X -> Add, and MatMul/BatchMatMul -> AllReduce -> ... -> Y -> Add
307 // Change it to MatMul/BatchMatMul -> ... -> X -> Add -> AllReduce and MatMul/BatchMatMul -> ... -> Y -> Add ->
308 // AllReduce thus it can reduce a communication op.
MatmulAddCommReduction(const FuncGraphPtr & graph,const opt::OptimizerPtr &)309 bool MatmulAddCommReduction(const FuncGraphPtr &graph, const opt::OptimizerPtr &) {
310 MS_EXCEPTION_IF_NULL(graph);
311 auto manager = graph->manager();
312 MS_EXCEPTION_IF_NULL(manager);
313 // assume no change to graph
314 bool changes = false;
315 HashMap<AnfNodePtr, std::vector<AnfNodePtr>> pull_down_node_map;
316 // candidate node to pull down
317 for (const auto &each_graph : manager->func_graphs()) {
318 FindAllValidAddNode(each_graph, &pull_down_node_map);
319 }
320 // Node Pull up
321 HandleAddNode(pull_down_node_map);
322 return changes;
323 }
324 } // namespace parallel
325 } // namespace mindspore
326