/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/grouped_pairwise_exchange_alltoall.h"
#include <memory>
#include <queue>
#include <utility>
#include <list>
#include <vector>
#include <string>
#include <algorithm>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/anf_utils.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
#include "frontend/parallel/device_matrix.h"
#include "pipeline/jit/ps/action.h"

namespace mindspore {
namespace opt {
namespace {
using CNodePtrPair = std::pair<CNodePtr, CNodePtr>;
using GpeaInfo = GroupedPairwiseExchangeAllToAllInfo;

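// Walk backward along the first-input chain of a gpea-marked node until an AlltoAll node is found.
// Other gpea-marked nodes passed on the way are recorded in visited_marked_nodes so they are not
// processed again by the caller.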
CNodePtr FindFrontAlltoall(const CNodePtr &marked_node, std::vector<CNodePtr> *visited_marked_nodes) {
  MS_EXCEPTION_IF_NULL(marked_node);
  auto input_node = marked_node->input(1);
  auto input_cnode = input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(input_cnode);
  std::queue<CNodePtr> node_queue;
  node_queue.push(input_cnode);

  CNodePtr alltoall_node = nullptr;
  while (!node_queue.empty()) {
    auto cnode = node_queue.front();
    node_queue.pop();
    if (IsPrimitiveCNode(cnode, prim::kPrimAlltoAll)) {
      alltoall_node = cnode;
      break;
    }

    if (cnode->HasAttr("gpea_label")) {
      visited_marked_nodes->push_back(cnode);
    }

    auto input = cnode->input(1);
    MS_EXCEPTION_IF_NULL(input);
    if (!input->isa<CNode>()) {
      break;
    }
    auto in_cnode = input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(in_cnode);
    node_queue.push(in_cnode);
  }

  if (alltoall_node == nullptr) {
    MS_LOG(WARNING) << "Can't find alltoall node before " << GetCNodePrimitive(marked_node)->name();
  }
  return alltoall_node;
}

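// Walk forward along the first-user chain of a gpea-marked node until an AlltoAll node is found,
// recording other gpea-marked nodes passed on the way in visited_marked_nodes.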
CNodePtr FindBackAlltoall(const FuncGraphManagerPtr &manager, const CNodePtr &marked_node,
                          std::vector<CNodePtr> *visited_marked_nodes) {
  MS_EXCEPTION_IF_NULL(marked_node);
  auto node_users_map = manager->node_users();
  auto node_users = node_users_map[marked_node];
  auto first_user = node_users.front().first;
  auto first_user_cnode = first_user->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(first_user_cnode);
  std::queue<CNodePtr> node_queue;
  node_queue.push(first_user_cnode);

  CNodePtr alltoall_node = nullptr;
  while (!node_queue.empty()) {
    auto cnode = node_queue.front();
    node_queue.pop();
    if (IsPrimitiveCNode(cnode, prim::kPrimAlltoAll)) {
      alltoall_node = cnode;
      break;
    }

    if (GetCNodePrimitive(cnode)->HasAttr("gpea_label")) {
      visited_marked_nodes->push_back(cnode);
    }

    auto cnode_users = node_users_map[cnode];
    if (cnode_users.empty()) {  // last cnode, exit while
      break;
    }
    auto first_node = cnode_users.front().first;
    MS_EXCEPTION_IF_NULL(first_node);
    auto first_cnode = first_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(first_cnode);
    node_queue.push(first_cnode);
  }

  if (alltoall_node == nullptr) {
    MS_LOG(WARNING) << "Can't find alltoall node after " << GetCNodePrimitive(marked_node)->name();
  }
  return alltoall_node;
}

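// Locate the AlltoAll nodes directly in front of and behind a gpea-marked node.
// Returns (nullptr, nullptr) if either side cannot be found.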
CNodePtrPair FindAlltoallPair(const FuncGraphManagerPtr &manager, const CNodePtr &marked_node,
                              std::vector<CNodePtr> *visited_marked_nodes) {
  auto front_alltoall = FindFrontAlltoall(marked_node, visited_marked_nodes);
  if (front_alltoall == nullptr) {
    CNodePtrPair null_alltoall_pair(nullptr, nullptr);
    return null_alltoall_pair;
  }

  auto back_alltoall = FindBackAlltoall(manager, marked_node, visited_marked_nodes);
  if (back_alltoall == nullptr) {
    CNodePtrPair null_alltoall_pair(nullptr, nullptr);
    return null_alltoall_pair;
  }

  CNodePtrPair alltoall_pair(front_alltoall, back_alltoall);
  return alltoall_pair;
}

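// Scan the graph in topological order and, for every unvisited node whose primitive carries the
// "gpea_label" attribute, collect the pair of AlltoAll nodes that surrounds it.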
void FindAlltoallNodePairs(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                           std::vector<CNodePtrPair> *alltoall_pairs) {
  std::vector<CNodePtr> visited_marked_nodes;
  for (size_t i = 0; i < origin_nodes_topological.size(); i++) {
    auto cnode = origin_nodes_topological[i];
    if (!IsPrimitiveCNode(cnode)) {
      continue;
    }

    if (!GetCNodePrimitive(cnode)->HasAttr("gpea_label")) {
      continue;
    }

    if (std::find(visited_marked_nodes.begin(), visited_marked_nodes.end(), cnode) != visited_marked_nodes.end()) {
      continue;
    }

    visited_marked_nodes.push_back(cnode);
    auto alltoall_pair = FindAlltoallPair(manager, cnode, &visited_marked_nodes);
    if (alltoall_pair.first == nullptr || alltoall_pair.second == nullptr) {
      MS_LOG(WARNING) << "Can't find alltoall_pair around cnode: " << GetCNodePrimitive(cnode)->name();
      continue;
    }
    alltoall_pairs->push_back(alltoall_pair);
  }
}

size_t GetSplitDimFromAlltoall(const AnfNodePtr &alltoall) {
  size_t split_dim = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(alltoall, kAttrSplitDim));
  return split_dim;
}

size_t GetConcatDimFromAlltoall(const AnfNodePtr &alltoall) {
  size_t concat_dim = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(alltoall, kAttrConcatDim));
  return concat_dim;
}

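// Create a Split CNode that splits input_node into split_num parts along split_dim, and set the
// inferred dtype/shape of every output to the input shape with split_dim divided by split_num.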
CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num) {
  if (split_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node, NewValueNode<int64_t>(split_dim),
                                          NewValueNode<int64_t>(split_num)};
  auto split = input_node->func_graph()->NewCNode(split_inputs);
  MS_EXCEPTION_IF_NULL(split);

  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  std::vector<TypeId> dtypes(split_num, dtype);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  shape[split_dim] /= SizeToLong(split_num);
  std::vector<ShapeVector> shapes(split_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
  split->set_scope(input_node->scope());
  return split;
}

CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num, const ShapeVector &input_shape,
                      const TypeId &input_dtype) {
  if (split_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node, NewValueNode<int64_t>(split_dim),
                                          NewValueNode<int64_t>(split_num)};
  auto split = input_node->func_graph()->NewCNode(split_inputs);
  MS_EXCEPTION_IF_NULL(split);

  std::vector<TypeId> dtypes(split_num, input_dtype);
  ShapeVector shape;
  for (size_t i = 0; i < input_shape.size(); i++) {
    shape.push_back(input_shape[i]);
  }
  shape[split_dim] /= SizeToLong(split_num);
  std::vector<ShapeVector> shapes(split_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
  split->set_scope(input_node->scope());
  return split;
}

CNodePtr NewConcatNode(const AnfNodePtr &input_node, size_t concat_dim, size_t input_num) {
  MS_EXCEPTION_IF_NULL(input_node);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           input_node, NewValueNode(MakeValue(static_cast<int64_t>(concat_dim)))};
  auto concat = input_node->func_graph()->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat);

  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node, 0)};
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  shape[concat_dim] *= SizeToLong(input_num);
  std::vector<ShapeVector> shapes(1, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
  concat->set_scope(input_node->scope());
  return concat;
}

CNodePtr NewMakeTupleNode(const std::vector<AnfNodePtr> &input_nodes) {
  // input_nodes are getitem nodes
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < input_nodes.size(); i++) {
    make_tuple_inputs.push_back(input_nodes[i]);
  }
  auto make_tuple = input_nodes[0]->func_graph()->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);

  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_nodes[0], 0);
  std::vector<TypeId> dtypes(input_nodes.size(), dtype);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_nodes[0], 0);
  std::vector<ShapeVector> shapes(input_nodes.size(), shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, make_tuple.get());
  make_tuple->set_scope(input_nodes[0]->scope());
  return make_tuple;
}

CNodePtr NewTupleGetItemNode(const AnfNodePtr &input_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto idx = NewValueNode(SizeToLong(output_index));
  MS_EXCEPTION_IF_NULL(idx);
  auto getitem = input_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input_node, idx});
  MS_EXCEPTION_IF_NULL(getitem);

  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node, output_index)};
  auto shapes = {common::AnfAlgo::GetOutputInferShape(input_node, output_index)};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, getitem.get());
  getitem->set_scope(input_node->scope());
  return getitem;
}

void MakeSortedSplitGetItemNodes(const AnfNodePtr &input_node, const std::vector<int64_t> &sort_idx,
                                 std::vector<AnfNodePtr> *getitem_nodes) {
  if (AnfUtils::GetOutputTensorNum(input_node) != sort_idx.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The number of MakeTuple inputs is not equal to sort index number";
  }

  for (size_t i = 0; i < sort_idx.size(); i++) {
    auto getitem = NewTupleGetItemNode(input_node, LongToSize(sort_idx[i]));
    getitem_nodes->push_back(getitem);
  }
}

void NewTupleGetItemNodes(const AnfNodePtr &input_node, size_t split_num, std::vector<AnfNodePtr> *getitem_nodes) {
  // input_node is a node such as split node or neighbor exchange node
  for (size_t i = 0; i < split_num; i++) {
    auto getitem = NewTupleGetItemNode(input_node, i);
    getitem_nodes->push_back(getitem);
  }
}

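// Create a NeighborExchange CNode from a MakeTuple input and set its communication attributes:
// receive dtype, world group, send/recv rank ids and send/recv shapes (every peer exchanges a
// slice with the same shape as the first MakeTuple input).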
CNodePtr NewNeighborExchangeNode(const AnfNodePtr &input_node, const std::vector<int64_t> &send_rank_ids,
                                 const std::vector<int64_t> &recv_rank_ids) {
  // input_node is maketuple node
  std::vector<AnfNodePtr> ne_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimNeighborExchange->name())),
                                       input_node};
  auto neighbor_exchange = input_node->func_graph()->NewCNode(ne_inputs);
  MS_EXCEPTION_IF_NULL(neighbor_exchange);
  auto input_cnode = input_node->cast<CNodePtr>();

  // RECV_TYPE
  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_TYPE, TypeIdToType(dtype), neighbor_exchange);

  // GROUP
  std::string group = parallel::g_device_manager->world_group();
  common::AnfAlgo::SetNodeAttr(parallel::GROUP, MakeValue<std::string>(group), neighbor_exchange);

  // SEND_RANK_IDS, RECV_RANK_IDS
  common::AnfAlgo::SetNodeAttr(parallel::SEND_RANK_IDS, parallel::MakeListValue(send_rank_ids), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_RANK_IDS, parallel::MakeListValue(recv_rank_ids), neighbor_exchange);

  // SEND_SHAPES, RECV_SHAPES
  auto maketuple_input = input_cnode->inputs()[1];
  parallel::Shape shape = common::AnfAlgo::GetOutputInferShape(maketuple_input, 0);
  parallel::Shapes send_shapes;
  parallel::Shapes recv_shapes;
  for (size_t i = 0; i < send_rank_ids.size(); i++) {
    send_shapes.push_back(shape);
    recv_shapes.push_back(shape);
  }
  common::AnfAlgo::SetNodeAttr(parallel::SEND_SHAPES, parallel::MakeTupleListValue(send_shapes), neighbor_exchange);
  common::AnfAlgo::SetNodeAttr(parallel::RECV_SHAPES, parallel::MakeTupleListValue(recv_shapes), neighbor_exchange);

  // set dtypes and shapes
  std::vector<TypeId> dtypes(recv_shapes.size(), dtype);
  std::vector<ShapeVector> shapes(recv_shapes.size(), shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, neighbor_exchange.get());

  neighbor_exchange->set_scope(input_node->scope());
  return neighbor_exchange;
}

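// Expand input_node into the communication pattern
//   Split -> TupleGetItem... -> MakeTuple -> NeighborExchange -> TupleGetItem... -> MakeTuple -> Concat
// (the leading Split is skipped when input_node is already a Split). All created nodes are appended
// to neighbor_exchange_nodes in creation order, so the caller can take back() as the pattern output.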
void CreateNeighborExchangeNodes(const AnfNodePtr &input_node, size_t split_dim, size_t concat_dim,
                                 const std::vector<int64_t> &send_rank_ids, const std::vector<int64_t> &recv_rank_ids,
                                 std::vector<AnfNodePtr> *neighbor_exchange_nodes) {
  CNodePtr split = nullptr;
  size_t send_num = send_rank_ids.size();
  std::vector<AnfNodePtr> getitem_nodes;
  if (IsPrimitiveCNode(input_node, prim::kPrimSplit)) {
    NewTupleGetItemNodes(input_node, send_num, &getitem_nodes);
  } else {
    split = NewSplitNode(input_node, split_dim, send_num);
    NewTupleGetItemNodes(split, send_num, &getitem_nodes);
  }

  auto maketuple = NewMakeTupleNode(getitem_nodes);
  auto neighbor_exchange = NewNeighborExchangeNode(maketuple, send_rank_ids, recv_rank_ids);
  std::vector<AnfNodePtr> getitem_nodes_after;
  size_t recv_num = recv_rank_ids.size();
  NewTupleGetItemNodes(neighbor_exchange, recv_num, &getitem_nodes_after);
  auto maketuple_after = NewMakeTupleNode(getitem_nodes_after);
  auto concat = NewConcatNode(maketuple_after, concat_dim, recv_num);

  if (split != nullptr) {
    neighbor_exchange_nodes->push_back(split);
  }
  (void)neighbor_exchange_nodes->insert(neighbor_exchange_nodes->end(), getitem_nodes.begin(), getitem_nodes.end());
  neighbor_exchange_nodes->push_back(maketuple);
  neighbor_exchange_nodes->push_back(neighbor_exchange);
  (void)neighbor_exchange_nodes->insert(neighbor_exchange_nodes->end(), getitem_nodes_after.begin(),
                                        getitem_nodes_after.end());
  neighbor_exchange_nodes->push_back(maketuple_after);
  neighbor_exchange_nodes->push_back(concat);
}

int64_t FindNodeIndex(const std::vector<CNodePtr> &node_vector, const CNodePtr &target_node) {
  auto iter = std::find(node_vector.begin(), node_vector.end(), target_node);
  if (iter == node_vector.end()) {
    return -1;
  } else {
    return std::distance(node_vector.begin(), iter);
  }
}

size_t FindAlltoallIndex(const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtr &alltoall) {
  int64_t idx = FindNodeIndex(origin_nodes_topological, alltoall);
  if (idx == -1) {
    MS_LOG(INTERNAL_EXCEPTION) << "Can not find alltoall node in origin_nodes_topological";
  }
  return LongToSize(idx);
}

ValueNodePtr ScaleShapeValueNode(const AnfNodePtr &old_shape_node, size_t scale_dim, int64_t scale_factor) {
  if (scale_factor == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "scale_factor should not be zero.";
  }
  auto shape_value_node = old_shape_node->cast<ValueNodePtr>();
  auto value_ptr = shape_value_node->value();
  std::vector<ValuePtr> value_ptr_vec = value_ptr->cast<ValueTuplePtr>()->value();
  ShapeVector new_shape;
  for (size_t i = 0; i < value_ptr_vec.size(); i++) {
    auto shape_value = GetValue<int64_t>(value_ptr_vec[i]);
    if (i == scale_dim) {
      shape_value /= scale_factor;
    }
    new_shape.push_back(shape_value);
  }
  return NewValueNode(MakeValue(new_shape));
}

const std::vector<CNodePtr> FindCNodesAmongAlltoall(const std::vector<CNodePtr> &origin_nodes_topological,
                                                    const CNodePtrPair &alltoall_pair) {
  auto front_alltoall = alltoall_pair.first;
  auto back_alltoall = alltoall_pair.second;
  size_t front_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, front_alltoall);
  size_t back_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, back_alltoall);
  std::vector<CNodePtr> cnodes;
  for (size_t i = front_alltoall_idx + 1; i < back_alltoall_idx; i++) {
    cnodes.push_back(origin_nodes_topological[i]);
  }
  return cnodes;
}

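// Clone the calculation CNodes between the AlltoAll pair so that each branch gets an independent copy.
// Load nodes are reused to avoid increasing device memory, ValueNode inputs are copied, and the shape
// operand of every Reshape is scaled by 1 / scale_factor along the axis given by reshape_scale_axis.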
void CloneScaledGraph(const std::vector<CNodePtr> &old_cnodes, const AnfNodePtr &input_node, size_t scale_factor,
                      GpeaInfo *gpea_info, std::vector<AnfNodePtr> *new_nodes) {
  mindspore::HashMap<CNodePtr, CNodePtr> cnode_map;
  auto input_cnode = input_node->cast<CNodePtr>();
  auto old_input_node = old_cnodes[0]->input(1);
  auto old_input_cnode = old_input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_input_cnode);
  cnode_map[old_input_cnode] = input_cnode;

  size_t reshape_cnt = 0;
  std::vector<uint32_t> reshape_scale_axis = gpea_info->GetReshapeScaleAxisVec();
  for (size_t i = 0; i < old_cnodes.size(); i++) {
    auto cnode = old_cnodes[i];
    MS_LOG(DEBUG) << "node in " << i << " " << GetCNodePrimitive(cnode)->name();
    if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
      new_nodes->push_back(cnode);  // reuse old Load node to not increase device memory
      continue;
    }

    // clone inputs
    std::vector<AnfNodePtr> new_inputs;
    auto inputs = cnode->inputs();
    for (size_t j = 0; j < inputs.size(); j++) {
      auto input = inputs[j];
      if (input->isa<CNode>()) {
        auto curr_input_cnode = input->cast<CNodePtr>();
        CNodePtr new_cnode;
        if (IsPrimitiveCNode(curr_input_cnode, prim::kPrimLoad)) {
          new_cnode = curr_input_cnode;
        } else {
          new_cnode = cnode_map[curr_input_cnode];
        }
        auto new_anf_node = new_cnode->cast<AnfNodePtr>();
        new_inputs.push_back(new_anf_node);
      } else if (input->isa<ValueNode>()) {
        ValueNodePtr new_value_node = NewValueNode(GetValueNode(input));
        new_inputs.push_back(new_value_node);
      } else if (input->isa<Parameter>()) {
        new_inputs.push_back(input);
      }
    }

    // scale reshape shape value
    if (IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
      new_inputs[kIndex2] = ScaleShapeValueNode(new_inputs[kIndex2], reshape_scale_axis[reshape_cnt], scale_factor);
      reshape_cnt += 1;
    }

    // create CNode
    auto new_cnode = input_node->func_graph()->NewCNode(new_inputs);
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_scope(cnode->scope());
    cnode_map[cnode] = new_cnode;
    new_nodes->push_back(new_cnode);
  }
}

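// Chain the cloned branches with Depend edges: the consumer of the front NeighborExchange in branch i
// additionally depends on the first node of the back communication in branch i + 1. As noted in the
// comment below, graph branch execution order is reverse, so a former branch depends on a latter one.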
void InsertDependOnBranches(const FuncGraphManagerPtr &manager, const std::vector<std::vector<AnfNodePtr>> &front_nodes,
                            const std::vector<std::vector<AnfNodePtr>> &back_nodes) {
  size_t comm_node_idx = 0;
  for (size_t i = 0; i < front_nodes[0].size(); i++) {
    auto cnode = front_nodes[0][i]->cast<CNodePtr>();
    if (IsPrimitiveCNode(cnode, prim::kPrimNeighborExchange)) {
      comm_node_idx = i;
      break;
    }
  }

  for (size_t branch_idx = 0; branch_idx < front_nodes.size() - 1; branch_idx++) {
    auto prev_node = front_nodes[branch_idx][comm_node_idx];
    auto node = front_nodes[branch_idx][comm_node_idx + 1];
    auto add_node = back_nodes[branch_idx + 1].front();
    // graph branch execution is reverse, so former branch depends on latter branch
    std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), prev_node, add_node};
    auto depend = node->func_graph()->NewCNode(depend_inputs);
    MS_EXCEPTION_IF_NULL(depend);
    depend->set_abstract(prev_node->abstract()->Clone());
    manager->SetEdge(node, 1, depend);
  }
}

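// Build the replacement subgraph for one AlltoAll pair: split the front AlltoAll input into group_num
// branches; for each branch create the front NeighborExchange communication, a scaled clone of the
// calculation graph between the two AlltoAll nodes, and the back NeighborExchange communication; then
// concatenate the branch outputs and reorder them (Concat -> Split -> TupleGetItem -> Concat) according
// to gpea_info->GetSortedInputsIdx().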
CNodePtr CreateReplaceGraph(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                            const CNodePtrPair &alltoall_pair, GpeaInfo *gpea_info) {
  auto front_alltoall = alltoall_pair.first;
  auto back_alltoall = alltoall_pair.second;
  auto graph_input = front_alltoall->input(1);

  // split input into several branches
  size_t front_split_dim = GetSplitDimFromAlltoall(front_alltoall);
  size_t front_concat_dim = GetConcatDimFromAlltoall(front_alltoall);
  size_t split_num = LongToSize(gpea_info->GetGroupNum());
  auto split = NewSplitNode(graph_input, front_split_dim, split_num);

  // clone the calculation graph for each branch
  size_t back_split_dim = GetSplitDimFromAlltoall(back_alltoall);
  size_t back_concat_dim = GetConcatDimFromAlltoall(back_alltoall);
  auto back_input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(back_alltoall, 0);
  if (split_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
  }
  back_input_shape[back_split_dim] /= SizeToLong(split_num);
  auto back_input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(back_alltoall, 0);
  std::vector<int64_t> send_group_ranks = gpea_info->GetSendGroupRanks();
  std::vector<std::vector<AnfNodePtr>> front_comm_branches;
  std::vector<std::vector<AnfNodePtr>> back_comm_branches;
  std::vector<AnfNodePtr> branch_output_nodes;
  auto old_calc_cnodes = FindCNodesAmongAlltoall(origin_nodes_topological, alltoall_pair);
  for (size_t i = 0; i < split_num; i++) {
    auto send_rank_ids = gpea_info->GetSendRankIds(i);
    auto recv_rank_ids = gpea_info->GetRecvRankIds(i);
    size_t split_branch_idx = LongToSize(send_group_ranks[i]);
    auto getitem = NewTupleGetItemNode(split, split_branch_idx);
    // create first neighbor exchange nodes
    std::vector<AnfNodePtr> front_neighbor_exchange_nodes;
    CreateNeighborExchangeNodes(getitem, front_split_dim, front_concat_dim, send_rank_ids, recv_rank_ids,
                                &front_neighbor_exchange_nodes);
    front_comm_branches.push_back(front_neighbor_exchange_nodes);
    // clone calculation nodes
    std::vector<AnfNodePtr> new_calc_nodes;
    CloneScaledGraph(old_calc_cnodes, front_neighbor_exchange_nodes.back(), split_num, gpea_info, &new_calc_nodes);
    MS_LOG(DEBUG) << "Create calculation done";
    // create second neighbor exchange nodes
    auto back_comm_split =
      NewSplitNode(new_calc_nodes.back(), back_split_dim, recv_rank_ids.size(), back_input_shape, back_input_dtype);
    std::vector<AnfNodePtr> back_neighbor_exchange_nodes;
    back_neighbor_exchange_nodes.push_back(back_comm_split);
    CreateNeighborExchangeNodes(back_comm_split, back_split_dim, back_concat_dim, recv_rank_ids, send_rank_ids,
                                &back_neighbor_exchange_nodes);
    branch_output_nodes.push_back(back_neighbor_exchange_nodes.back());
    back_comm_branches.push_back(back_neighbor_exchange_nodes);
  }
  MS_LOG(DEBUG) << "Create multi branch done";
  InsertDependOnBranches(manager, front_comm_branches, back_comm_branches);
  MS_LOG(DEBUG) << "InsertDependOnBranches done";

  // concat several branches into one branch
  auto maketuple = NewMakeTupleNode(branch_output_nodes);
  auto concat = NewConcatNode(maketuple, back_concat_dim, split_num);
  auto reorder_split = NewSplitNode(concat, back_concat_dim, split_num);
  std::vector<AnfNodePtr> reorder_getitem_nodes;
  std::vector<int64_t> sort_idx = gpea_info->GetSortedInputsIdx();
  MakeSortedSplitGetItemNodes(reorder_split, sort_idx, &reorder_getitem_nodes);
  auto reorder_maketuple = NewMakeTupleNode(reorder_getitem_nodes);
  auto reorder_concat = NewConcatNode(reorder_maketuple, back_concat_dim, split_num);
  MS_LOG(DEBUG) << "Create replace graph done";
  return reorder_concat;
}

void CreateAndReplaceAlltoall(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                              const CNodePtrPair &alltoall_pair, GpeaInfo *gpea_info) {
  auto cnode = CreateReplaceGraph(manager, origin_nodes_topological, alltoall_pair, gpea_info);
  (void)manager->Replace(alltoall_pair.second, cnode);
}

void CreateAndReplaceGraph(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
                           const std::vector<CNodePtrPair> &alltoall_pairs, GpeaInfo *gpea_info) {
  for (size_t i = 0; i < alltoall_pairs.size(); i++) {
    CreateAndReplaceAlltoall(manager, origin_nodes_topological, alltoall_pairs[i], gpea_info);
  }
}

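// Validate the user-provided 'reshape_scale_axis': its length must equal the number of Reshape nodes
// between the AlltoAll pair, since each cloned Reshape needs to know which axis to scale. On mismatch,
// dump the nodes between the pair together with their input/output shapes and raise an exception.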
void CheckReshapeScaleAxis(const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtrPair &alltoall_pair,
                           GpeaInfo *gpea_info) {
  size_t front_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, alltoall_pair.first);
  size_t back_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, alltoall_pair.second);
  size_t reshape_cnode_num = 0;
  for (size_t i = front_alltoall_idx + 1; i < back_alltoall_idx; i++) {
    auto cnode = origin_nodes_topological[i];
    if (IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
      reshape_cnode_num += 1;
    }
  }

  std::vector<uint32_t> reshape_scale_axis = gpea_info->GetReshapeScaleAxisVec();
  size_t axis_length = reshape_scale_axis.size();
  MS_LOG(DEBUG) << "The graph has " << reshape_cnode_num << " reshape nodes, which will be scaled on certain axes";
  if (axis_length == reshape_cnode_num) {
    MS_LOG(DEBUG) << "'reshape_scale_axis' " << reshape_scale_axis;
    return;
  } else {
    MS_LOG(DEBUG) << "'reshape_scale_axis' has " << axis_length
                  << " elements; its length should be the same as the number of reshape nodes to be scaled";
    MS_LOG(DEBUG) << "Show graph nodes start";
    for (size_t i = front_alltoall_idx; i < back_alltoall_idx + 1; i++) {
      auto cnode = origin_nodes_topological[i];
      auto prim_name = GetCNodePrimitive(cnode)->name();
      auto scope_name = cnode->scope()->name();
      std::string input_shape_string = "";
      std::string output_shape_string = "";
      std::string space = " ";
      size_t output_num = AnfUtils::GetOutputTensorNum(cnode);
      size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
      for (size_t j = 0; j < input_num; j++) {
        auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, j);
        if (shape.size() > 0) {
          input_shape_string += space;
          input_shape_string += parallel::ShapeToString(shape);
        }
      }
      for (size_t j = 0; j < output_num; j++) {
        auto shape = common::AnfAlgo::GetOutputInferShape(cnode, j);
        if (shape.size() > 0) {
          output_shape_string += space;
          output_shape_string += parallel::ShapeToString(shape);
        }
      }

      MS_LOG(DEBUG) << "Node: " << prim_name;
      MS_LOG(DEBUG) << "Scope: " << scope_name;
      MS_LOG(DEBUG) << "Input shapes: " << input_shape_string;
      MS_LOG(DEBUG) << "Output shapes: " << output_shape_string;
    }
    MS_LOG(DEBUG) << "Show graph nodes end";
    MS_LOG(EXCEPTION) << "The size of 'reshape_scale_axis' is not equal to the number of reshape nodes in the graph. "
                      << "There are " << reshape_cnode_num
                      << " reshape nodes to be scaled. Please set a correct scale axis for each reshape node";
  }
}

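// Check whether the pass is enabled by the user settings: semi_auto_parallel mode, enable_alltoall,
// inference graph (not training), and a group number in (1, device_num) that divides device_num.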
bool CheckUserSettings(const FuncGraphPtr &fg, GpeaInfo *gpea_info) {
  if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel) {
    MS_LOG(DEBUG) << "To activate the pass, set_auto_parallel_context 'parallel_mode' should be 'semi_auto_parallel'";
    return false;
  }

  if (!parallel::ParallelContext::GetInstance()->enable_all2all()) {
    MS_LOG(DEBUG) << "To activate the pass, set_auto_parallel_context 'enable_alltoall' should be true";
    return false;
  }

  if (fg->has_flag(kTraining)) {
    MS_LOG(DEBUG) << "To activate the pass, network 'set_train' should be false";
    return false;
  }

  gpea_info->DisplayInfo();

  int64_t gpea_num = gpea_info->GetGroupNum();
  if (gpea_num <= 1 || LongToSize(gpea_num) == GetDeviceNum()) {
    MS_LOG(DEBUG) << "To activate the pass, gpea_num " << gpea_num << " should be between (1, " << GetDeviceNum()
                  << ")";
    return false;
  }

  if (GetDeviceNum() % LongToSize(gpea_num) != 0) {
    MS_LOG(DEBUG) << "To activate the pass, device num " << GetDeviceNum() << " should be divisible by gpea_num "
                  << LongToSize(gpea_num);
    return false;
  }
  return true;
}
}  // namespace

size_t GetDeviceNum() { return parallel::g_device_manager->DeviceNum(); }

size_t GetGlobalRankID() { return LongToSize(parallel::g_device_manager->global_rank()); }

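// Entry point of the pass: when the user settings allow it, find every AlltoAll pair marked with the
// gpea_label attribute, replace it with the grouped pairwise exchange subgraph built above, and
// renormalize the graph so that all newly created nodes get proper abstracts.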
void SetGroupedPairwiseExchangeAllToAll(const pipeline::ResourcePtr &resource) {
  if (parallel::g_device_manager == nullptr) {
    MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
    return;
  }

  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);

  auto gpea_info = GpeaInfo();
  if (!CheckUserSettings(func_graph, &gpea_info)) {
    return;
  }

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::list<CNodePtr> orders = func_graph->GetOrderedCnodes();
  std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());

  std::vector<CNodePtrPair> alltoall_pairs;
  FindAlltoallNodePairs(manager, origin_nodes_topological, &alltoall_pairs);
  MS_LOG(DEBUG) << "Find alltoall_pairs num: " << alltoall_pairs.size();
  if (alltoall_pairs.size() == 0) {
    MS_LOG(WARNING) << "No alltoall_pair found, skip the pass";
    return;
  }

  CheckReshapeScaleAxis(origin_nodes_topological, alltoall_pairs[0], &gpea_info);

  CreateAndReplaceGraph(manager, origin_nodes_topological, alltoall_pairs, &gpea_info);
  MS_LOG(DEBUG) << "CreateAndReplaceGraph done";

  // Renormalize, infer shape and set abstract for all nodes in graph
  abstract::AbstractBasePtrList args_abs;
  auto parameters = func_graph->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  FuncGraphPtr new_fg = pipeline::Renormalize(resource, func_graph, args_abs);
  resource->set_func_graph(new_fg);
  resource->set_args_abs(args_abs);
  MS_LOG(DEBUG) << "Renormalize done";
  return;
}
}  // namespace opt
}  // namespace mindspore