• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/grouped_pairwise_exchange_alltoall.h"
18 #include <memory>
19 #include <queue>
20 #include <utility>
21 #include <list>
22 #include <vector>
23 #include <string>
24 #include <algorithm>
25 
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/other_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "utils/anf_utils.h"
31 #include "include/common/utils/utils.h"
32 #include "include/common/utils/anfalgo.h"
33 #include "include/common/utils/parallel_context.h"
34 #include "frontend/parallel/ops_info/ops_utils.h"
35 #include "frontend/parallel/ops_info/operator_info.h"
36 #include "frontend/parallel/tensor_layout/tensor_info.h"
37 #include "frontend/parallel/device_matrix.h"
38 #include "pipeline/jit/ps/action.h"
39 
40 namespace mindspore {
41 namespace opt {
42 namespace {
43 using CNodePtrPair = std::pair<CNodePtr, CNodePtr>;
44 using GpeaInfo = GroupedPairwiseExchangeAllToAllInfo;
45 
// Searches upstream of `marked_node` for the nearest AlltoAll node.
//
// The walk starts at input(1) of `marked_node` and keeps following input(1) of each visited CNode.
// It stops at the first AlltoAll, at the first input that is not a CNode, or at a CNode that has no
// input(1) at all.
//
// marked_node:          node the search starts from; its input(1) must be a CNode.
// visited_marked_nodes: receives every walked CNode that carries the "gpea_label" attribute.
// Returns the AlltoAll node, or nullptr (after logging a warning) when none is found.
CNodePtr FindFrontAlltoall(const CNodePtr &marked_node, std::vector<CNodePtr> *visited_marked_nodes) {
  MS_EXCEPTION_IF_NULL(marked_node);
  MS_EXCEPTION_IF_NULL(visited_marked_nodes);
  constexpr size_t kDataInputIndex = 1;
  auto input_node = marked_node->input(kDataInputIndex);
  MS_EXCEPTION_IF_NULL(input_node);
  auto cursor = input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cursor);

  CNodePtr alltoall_node = nullptr;
  while (cursor != nullptr) {
    if (IsPrimitiveCNode(cursor, prim::kPrimAlltoAll)) {
      alltoall_node = cursor;
      break;
    }

    // NOTE(review): this tests the attribute on the CNode itself, while FindBackAlltoall and
    // FindAlltoallNodePairs test it on the node's primitive - confirm which one is intended.
    if (cursor->HasAttr("gpea_label")) {
      visited_marked_nodes->push_back(cursor);
    }

    // A CNode without a data input ends the chain; treat it like reaching a non-CNode input
    // instead of asking for an input that does not exist.
    if (cursor->inputs().size() <= kDataInputIndex) {
      break;
    }
    auto input = cursor->input(kDataInputIndex);
    MS_EXCEPTION_IF_NULL(input);
    // cast() yields nullptr for anything that is not a CNode, which terminates the walk.
    cursor = input->cast<CNodePtr>();
  }

  if (alltoall_node == nullptr) {
    MS_LOG(WARNING) << "Can't find alltoall node before " << GetCNodePrimitive(marked_node)->name();
  }
  return alltoall_node;
}
82 
// Searches downstream of `marked_node` for the nearest AlltoAll node.
//
// The walk starts at the first recorded user of `marked_node` and keeps moving to the first user of
// each visited CNode. It stops at the first AlltoAll or at a node that has no users.
//
// manager:              graph manager whose node-user table is consulted.
// marked_node:          node the search starts from.
// visited_marked_nodes: receives every walked CNode whose primitive carries the "gpea_label" attribute.
// Returns the AlltoAll node, or nullptr (after logging a warning) when none is found.
CNodePtr FindBackAlltoall(const FuncGraphManagerPtr &manager, const CNodePtr &marked_node,
                          std::vector<CNodePtr> *visited_marked_nodes) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(marked_node);
  MS_EXCEPTION_IF_NULL(visited_marked_nodes);
  // Bind by reference: copying the whole user table is wasteful, and find() (unlike operator[])
  // never inserts entries into the manager's table.
  auto &node_users_map = manager->node_users();
  auto first_user_of = [&node_users_map](const AnfNodePtr &node) -> CNodePtr {
    auto iter = node_users_map.find(node);
    if (iter == node_users_map.end() || iter->second.empty()) {
      return nullptr;  // no user recorded: end of the chain
    }
    auto user = iter->second.front().first;
    MS_EXCEPTION_IF_NULL(user);
    auto user_cnode = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    return user_cnode;
  };

  CNodePtr alltoall_node = nullptr;
  for (auto cnode = first_user_of(marked_node); cnode != nullptr; cnode = first_user_of(cnode)) {
    if (IsPrimitiveCNode(cnode, prim::kPrimAlltoAll)) {
      alltoall_node = cnode;
      break;
    }

    // A user does not have to be a primitive call, so the primitive may be absent.
    auto prim = GetCNodePrimitive(cnode);
    if (prim != nullptr && prim->HasAttr("gpea_label")) {
      visited_marked_nodes->push_back(cnode);
    }
  }

  if (alltoall_node == nullptr) {
    MS_LOG(WARNING) << "Can't find alltoall node after " << GetCNodePrimitive(marked_node)->name();
  }
  return alltoall_node;
}
123 
FindAlltoallPair(const FuncGraphManagerPtr & manager,const CNodePtr & marked_node,std::vector<CNodePtr> * visited_marked_nodes)124 CNodePtrPair FindAlltoallPair(const FuncGraphManagerPtr &manager, const CNodePtr &marked_node,
125                               std::vector<CNodePtr> *visited_marked_nodes) {
126   auto front_alltoall = FindFrontAlltoall(marked_node, visited_marked_nodes);
127   if (front_alltoall == nullptr) {
128     CNodePtrPair null_alltoall_pair(nullptr, nullptr);
129     return null_alltoall_pair;
130   }
131 
132   auto back_alltoall = FindBackAlltoall(manager, marked_node, visited_marked_nodes);
133   if (back_alltoall == nullptr) {
134     CNodePtrPair null_alltoall_pair(nullptr, nullptr);
135     return null_alltoall_pair;
136   }
137 
138   CNodePtrPair alltoall_pair(front_alltoall, back_alltoall);
139   return alltoall_pair;
140 }
141 
FindAlltoallNodePairs(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,std::vector<CNodePtrPair> * alltoall_pairs)142 void FindAlltoallNodePairs(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
143                            std::vector<CNodePtrPair> *alltoall_pairs) {
144   std::vector<CNodePtr> visited_marked_nodes;
145   for (size_t i = 0; i < origin_nodes_topological.size(); i++) {
146     auto cnode = origin_nodes_topological[i];
147     if (!IsPrimitiveCNode(cnode)) {
148       continue;
149     }
150 
151     if (!GetCNodePrimitive(cnode)->HasAttr("gpea_label")) {
152       continue;
153     }
154 
155     if (std::find(visited_marked_nodes.begin(), visited_marked_nodes.end(), cnode) != visited_marked_nodes.end()) {
156       continue;
157     }
158 
159     visited_marked_nodes.push_back(cnode);
160     auto alltoall_pair = FindAlltoallPair(manager, cnode, &visited_marked_nodes);
161     if (alltoall_pair.first == nullptr || alltoall_pair.second == nullptr) {
162       MS_LOG(WARNING) << "not find alltoall_pair around cnode: " << GetCNodePrimitive(cnode)->name();
163       continue;
164     }
165     alltoall_pairs->push_back(alltoall_pair);
166   }
167 }
168 
GetSplitDimFromAlltoall(const AnfNodePtr & alltoall)169 size_t GetSplitDimFromAlltoall(const AnfNodePtr &alltoall) {
170   size_t split_dim = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(alltoall, kAttrSplitDim));
171   return split_dim;
172 }
173 
GetConcatDimFromAlltoall(const AnfNodePtr & alltoall)174 size_t GetConcatDimFromAlltoall(const AnfNodePtr &alltoall) {
175   size_t concat_dim = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(alltoall, kAttrConcatDim));
176   return concat_dim;
177 }
178 
NewSplitNode(const AnfNodePtr & input_node,size_t split_dim,size_t split_num)179 CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num) {
180   if (split_num == 0) {
181     MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
182   }
183   MS_EXCEPTION_IF_NULL(input_node);
184   std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
185                                           input_node, NewValueNode<int64_t>(split_dim),
186                                           NewValueNode<int64_t>(split_num)};
187   auto split = input_node->func_graph()->NewCNode(split_inputs);
188   MS_EXCEPTION_IF_NULL(split);
189 
190   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
191   std::vector<TypeId> dtypes(split_num, dtype);
192   auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
193   shape[split_dim] /= SizeToLong(split_num);
194   std::vector<ShapeVector> shapes(split_num, shape);
195   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
196   split->set_scope(input_node->scope());
197   return split;
198 }
199 
NewSplitNode(const AnfNodePtr & input_node,size_t split_dim,size_t split_num,const ShapeVector & input_shape,const TypeId & input_dtype)200 CNodePtr NewSplitNode(const AnfNodePtr &input_node, size_t split_dim, size_t split_num, const ShapeVector &input_shape,
201                       const TypeId &input_dtype) {
202   if (split_num == 0) {
203     MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
204   }
205   MS_EXCEPTION_IF_NULL(input_node);
206   std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
207                                           input_node, NewValueNode<int64_t>(split_dim),
208                                           NewValueNode<int64_t>(split_num)};
209   auto split = input_node->func_graph()->NewCNode(split_inputs);
210   MS_EXCEPTION_IF_NULL(split);
211 
212   std::vector<TypeId> dtypes(split_num, input_dtype);
213   ShapeVector shape;
214   for (size_t i = 0; i < input_shape.size(); i++) {
215     shape.push_back(input_shape[i]);
216   }
217   shape[split_dim] /= SizeToLong(split_num);
218   std::vector<ShapeVector> shapes(split_num, shape);
219   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
220   split->set_scope(input_node->scope());
221   return split;
222 }
223 
NewConcatNode(const AnfNodePtr & input_node,size_t concat_dim,size_t input_num)224 CNodePtr NewConcatNode(const AnfNodePtr &input_node, size_t concat_dim, size_t input_num) {
225   MS_EXCEPTION_IF_NULL(input_node);
226   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
227                                            input_node, NewValueNode(MakeValue(static_cast<int64_t>(concat_dim)))};
228   auto concat = input_node->func_graph()->NewCNode(concat_inputs);
229   MS_EXCEPTION_IF_NULL(concat);
230 
231   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node, 0)};
232   auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
233   shape[concat_dim] *= SizeToLong(input_num);
234   std::vector<ShapeVector> shapes(1, shape);
235   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
236   concat->set_scope(input_node->scope());
237   return concat;
238 }
239 
NewMakeTupleNode(const std::vector<AnfNodePtr> & input_nodes)240 CNodePtr NewMakeTupleNode(const std::vector<AnfNodePtr> &input_nodes) {
241   // input_nodes are getitem nodes
242   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
243   for (size_t i = 0; i < input_nodes.size(); i++) {
244     make_tuple_inputs.push_back(input_nodes[i]);
245   }
246   auto make_tuple = input_nodes[0]->func_graph()->NewCNode(make_tuple_inputs);
247   MS_EXCEPTION_IF_NULL(make_tuple);
248 
249   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_nodes[0], 0);
250   std::vector<TypeId> dtypes(input_nodes.size(), dtype);
251   auto shape = common::AnfAlgo::GetOutputInferShape(input_nodes[0], 0);
252   std::vector<ShapeVector> shapes(input_nodes.size(), shape);
253   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, make_tuple.get());
254   make_tuple->set_scope(input_nodes[0]->scope());
255   return make_tuple;
256 }
257 
NewTupleGetItemNode(const AnfNodePtr & input_node,size_t output_index)258 CNodePtr NewTupleGetItemNode(const AnfNodePtr &input_node, size_t output_index) {
259   MS_EXCEPTION_IF_NULL(input_node);
260   auto idx = NewValueNode(SizeToLong(output_index));
261   MS_EXCEPTION_IF_NULL(idx);
262   auto getitem = input_node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input_node, idx});
263   MS_EXCEPTION_IF_NULL(getitem);
264 
265   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node, output_index)};
266   auto shapes = {common::AnfAlgo::GetOutputInferShape(input_node, output_index)};
267   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, getitem.get());
268   getitem->set_scope(input_node->scope());
269   return getitem;
270 }
271 
MakeSortedSplitGetItemNodes(const AnfNodePtr & input_node,const std::vector<int64_t> & sort_idx,std::vector<AnfNodePtr> * getitem_nodes)272 void MakeSortedSplitGetItemNodes(const AnfNodePtr &input_node, const std::vector<int64_t> &sort_idx,
273                                  std::vector<AnfNodePtr> *getitem_nodes) {
274   if (AnfUtils::GetOutputTensorNum(input_node) != sort_idx.size()) {
275     MS_LOG(INTERNAL_EXCEPTION) << "The number of MakeTuple inputs is not equal to sort index number";
276   }
277 
278   for (size_t i = 0; i < sort_idx.size(); i++) {
279     auto getitem = NewTupleGetItemNode(input_node, LongToSize(sort_idx[i]));
280     getitem_nodes->push_back(getitem);
281   }
282 }
283 
NewTupleGetItemNodes(const AnfNodePtr & input_node,size_t split_num,std::vector<AnfNodePtr> * getitem_nodes)284 void NewTupleGetItemNodes(const AnfNodePtr &input_node, size_t split_num, std::vector<AnfNodePtr> *getitem_nodes) {
285   // input_node is a node such as split node or neighbor exchange node
286   for (size_t i = 0; i < split_num; i++) {
287     auto getitem = NewTupleGetItemNode(input_node, i);
288     getitem_nodes->push_back(getitem);
289   }
290 }
291 
NewNeighborExchangeNode(const AnfNodePtr & input_node,const std::vector<int64_t> & send_rank_ids,const std::vector<int64_t> & recv_rank_ids)292 CNodePtr NewNeighborExchangeNode(const AnfNodePtr &input_node, const std::vector<int64_t> &send_rank_ids,
293                                  const std::vector<int64_t> &recv_rank_ids) {
294   // input_node is maketuple node
295   std::vector<AnfNodePtr> ne_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimNeighborExchange->name())),
296                                        input_node};
297   auto neighbor_exchange = input_node->func_graph()->NewCNode(ne_inputs);
298   MS_EXCEPTION_IF_NULL(neighbor_exchange);
299   auto input_cnode = input_node->cast<CNodePtr>();
300 
301   // RECV_TYPE
302   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
303   common::AnfAlgo::SetNodeAttr(parallel::RECV_TYPE, TypeIdToType(dtype), neighbor_exchange);
304 
305   // GROUP
306   std::string group = parallel::g_device_manager->world_group();
307   common::AnfAlgo::SetNodeAttr(parallel::GROUP, MakeValue<std::string>(group), neighbor_exchange);
308 
309   // SEND_RANK_IDS, RECV_RANK_IDS
310   common::AnfAlgo::SetNodeAttr(parallel::SEND_RANK_IDS, parallel::MakeListValue(send_rank_ids), neighbor_exchange);
311   common::AnfAlgo::SetNodeAttr(parallel::RECV_RANK_IDS, parallel::MakeListValue(recv_rank_ids), neighbor_exchange);
312 
313   // SEND_SHAPES, RECV_SHAPES
314   auto maketuple_input = input_cnode->inputs()[1];
315   parallel::Shape shape = common::AnfAlgo::GetOutputInferShape(maketuple_input, 0);
316   parallel::Shapes send_shapes;
317   parallel::Shapes recv_shapes;
318   for (size_t i = 0; i < send_rank_ids.size(); i++) {
319     send_shapes.push_back(shape);
320     recv_shapes.push_back(shape);
321   }
322   common::AnfAlgo::SetNodeAttr(parallel::SEND_SHAPES, parallel::MakeTupleListValue(send_shapes), neighbor_exchange);
323   common::AnfAlgo::SetNodeAttr(parallel::RECV_SHAPES, parallel::MakeTupleListValue(recv_shapes), neighbor_exchange);
324 
325   // set dtypes and shapes
326   std::vector<TypeId> dtypes(recv_shapes.size(), dtype);
327   std::vector<ShapeVector> shapes(recv_shapes.size(), shape);
328   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, neighbor_exchange.get());
329 
330   neighbor_exchange->set_scope(input_node->scope());
331   return neighbor_exchange;
332 }
333 
CreateNeighborExchangeNodes(const AnfNodePtr & input_node,size_t split_dim,size_t concat_dim,const std::vector<int64_t> & send_rank_ids,const std::vector<int64_t> & recv_rank_ids,std::vector<AnfNodePtr> * neighbor_exchange_nodes)334 void CreateNeighborExchangeNodes(const AnfNodePtr &input_node, size_t split_dim, size_t concat_dim,
335                                  const std::vector<int64_t> &send_rank_ids, const std::vector<int64_t> &recv_rank_ids,
336                                  std::vector<AnfNodePtr> *neighbor_exchange_nodes) {
337   CNodePtr split = nullptr;
338   size_t send_num = send_rank_ids.size();
339   std::vector<AnfNodePtr> getitem_nodes;
340   if (IsPrimitiveCNode(input_node, prim::kPrimSplit)) {
341     NewTupleGetItemNodes(input_node, send_num, &getitem_nodes);
342   } else {
343     split = NewSplitNode(input_node, split_dim, send_num);
344     NewTupleGetItemNodes(split, send_num, &getitem_nodes);
345   }
346 
347   auto maketuple = NewMakeTupleNode(getitem_nodes);
348   auto neighbor_exchange = NewNeighborExchangeNode(maketuple, send_rank_ids, recv_rank_ids);
349   std::vector<AnfNodePtr> getitem_nodes_after;
350   size_t recv_num = recv_rank_ids.size();
351   NewTupleGetItemNodes(neighbor_exchange, recv_num, &getitem_nodes_after);
352   auto maketuple_after = NewMakeTupleNode(getitem_nodes_after);
353   auto concat = NewConcatNode(maketuple_after, concat_dim, recv_num);
354 
355   if (split != nullptr) {
356     neighbor_exchange_nodes->push_back(split);
357   }
358   (void)neighbor_exchange_nodes->insert(neighbor_exchange_nodes->end(), getitem_nodes.begin(), getitem_nodes.end());
359   neighbor_exchange_nodes->push_back(maketuple);
360   neighbor_exchange_nodes->push_back(neighbor_exchange);
361   (void)neighbor_exchange_nodes->insert(neighbor_exchange_nodes->end(), getitem_nodes_after.begin(),
362                                         getitem_nodes_after.end());
363   neighbor_exchange_nodes->push_back(maketuple_after);
364   neighbor_exchange_nodes->push_back(concat);
365 }
366 
FindNodeIndex(const std::vector<CNodePtr> & node_vector,const CNodePtr & target_node)367 int64_t FindNodeIndex(const std::vector<CNodePtr> &node_vector, const CNodePtr &target_node) {
368   auto iter = std::find(node_vector.begin(), node_vector.end(), target_node);
369   if (iter == node_vector.end()) {
370     return -1;
371   } else {
372     return std::distance(node_vector.begin(), iter);
373   }
374 }
375 
FindAlltoallIndex(const std::vector<CNodePtr> & origin_nodes_topological,const CNodePtr & alltoall)376 size_t FindAlltoallIndex(const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtr &alltoall) {
377   int64_t idx = FindNodeIndex(origin_nodes_topological, alltoall);
378   if (idx == -1) {
379     MS_LOG(INTERNAL_EXCEPTION) << "Can not find alltoall node in origin_nodes_topological";
380   }
381   return LongToSize(idx);
382 }
383 
ScaleShapeValueNode(const AnfNodePtr & old_shape_node,size_t scale_dim,int64_t scale_factor)384 ValueNodePtr ScaleShapeValueNode(const AnfNodePtr &old_shape_node, size_t scale_dim, int64_t scale_factor) {
385   if (scale_factor == 0) {
386     MS_LOG(INTERNAL_EXCEPTION) << "scale_factor should not be zero.";
387   }
388   auto shape_value_node = old_shape_node->cast<ValueNodePtr>();
389   auto value_ptr = shape_value_node->value();
390   std::vector<ValuePtr> value_ptr_vec = value_ptr->cast<ValueTuplePtr>()->value();
391   ShapeVector new_shape;
392   for (size_t i = 0; i < value_ptr_vec.size(); i++) {
393     auto shape_value = GetValue<int64_t>(value_ptr_vec[i]);
394     if (i == scale_dim) {
395       shape_value /= scale_factor;
396     }
397     new_shape.push_back(shape_value);
398   }
399   return NewValueNode(MakeValue(new_shape));
400 }
401 
FindCNodesAmongAlltoall(const std::vector<CNodePtr> & origin_nodes_topological,const CNodePtrPair & alltoall_pair)402 const std::vector<CNodePtr> FindCNodesAmongAlltoall(const std::vector<CNodePtr> &origin_nodes_topological,
403                                                     const CNodePtrPair &alltoall_pair) {
404   auto front_alltoall = alltoall_pair.first;
405   auto back_alltoall = alltoall_pair.second;
406   size_t front_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, front_alltoall);
407   size_t back_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, back_alltoall);
408   std::vector<CNodePtr> cnodes;
409   for (size_t i = front_alltoall_idx + 1; i < back_alltoall_idx; i++) {
410     cnodes.push_back(origin_nodes_topological[i]);
411   }
412   return cnodes;
413 }
414 
CloneScaledGraph(const std::vector<CNodePtr> & old_cnodes,const AnfNodePtr & input_node,size_t scale_factor,GpeaInfo * gpea_info,std::vector<AnfNodePtr> * new_nodes)415 void CloneScaledGraph(const std::vector<CNodePtr> &old_cnodes, const AnfNodePtr &input_node, size_t scale_factor,
416                       GpeaInfo *gpea_info, std::vector<AnfNodePtr> *new_nodes) {
417   mindspore::HashMap<CNodePtr, CNodePtr> cnode_map;
418   auto input_cnode = input_node->cast<CNodePtr>();
419   auto old_input_node = old_cnodes[0]->input(1);
420   auto old_input_cnode = old_input_node->cast<CNodePtr>();
421   MS_EXCEPTION_IF_NULL(old_input_cnode);
422   cnode_map[old_input_cnode] = input_cnode;
423 
424   size_t reshape_cnt = 0;
425   std::vector<uint32_t> reshape_scale_axis = gpea_info->GetReshapeScaleAxisVec();
426   for (size_t i = 0; i < old_cnodes.size(); i++) {
427     auto cnode = old_cnodes[i];
428     MS_LOG(DEBUG) << "node in " << i << " " << GetCNodePrimitive(cnode)->name();
429     if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
430       new_nodes->push_back(cnode);  // reuse old Load node to not increase device memory
431       continue;
432     }
433 
434     // clone inputs
435     std::vector<AnfNodePtr> new_inputs;
436     auto inputs = cnode->inputs();
437     for (size_t j = 0; j < inputs.size(); j++) {
438       auto input = inputs[j];
439       if (input->isa<CNode>()) {
440         auto curr_input_cnode = input->cast<CNodePtr>();
441         CNodePtr new_cnode;
442         if (IsPrimitiveCNode(curr_input_cnode, prim::kPrimLoad)) {
443           new_cnode = curr_input_cnode;
444         } else {
445           new_cnode = cnode_map[curr_input_cnode];
446         }
447         auto new_anf_node = new_cnode->cast<AnfNodePtr>();
448         new_inputs.push_back(new_anf_node);
449       } else if (input->isa<ValueNode>()) {
450         ValueNodePtr new_value_node = NewValueNode(GetValueNode(input));
451         new_inputs.push_back(new_value_node);
452       } else if (input->isa<Parameter>()) {
453         new_inputs.push_back(input);
454       }
455     }
456 
457     // scale reshape shape value
458     if (IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
459       new_inputs[kIndex2] = ScaleShapeValueNode(new_inputs[kIndex2], reshape_scale_axis[reshape_cnt], scale_factor);
460       reshape_cnt += 1;
461     }
462 
463     // create CNode
464     auto new_cnode = input_node->func_graph()->NewCNode(new_inputs);
465     MS_EXCEPTION_IF_NULL(new_cnode);
466     new_cnode->set_scope(cnode->scope());
467     cnode_map[cnode] = new_cnode;
468     new_nodes->push_back(new_cnode);
469   }
470 }
471 
InsertDependOnBranches(const FuncGraphManagerPtr & manager,const std::vector<std::vector<AnfNodePtr>> & front_nodes,const std::vector<std::vector<AnfNodePtr>> & back_nodes)472 void InsertDependOnBranches(const FuncGraphManagerPtr &manager, const std::vector<std::vector<AnfNodePtr>> &front_nodes,
473                             const std::vector<std::vector<AnfNodePtr>> &back_nodes) {
474   size_t comm_node_idx = 0;
475   for (size_t i = 0; i < front_nodes[0].size(); i++) {
476     auto cnode = front_nodes[0][i]->cast<CNodePtr>();
477     if (IsPrimitiveCNode(cnode, prim::kPrimNeighborExchange)) {
478       comm_node_idx = i;
479       break;
480     }
481   }
482 
483   for (size_t branch_idx = 0; branch_idx < front_nodes.size() - 1; branch_idx++) {
484     auto prev_node = front_nodes[branch_idx][comm_node_idx];
485     auto node = front_nodes[branch_idx][comm_node_idx + 1];
486     auto add_node = back_nodes[branch_idx + 1].front();
487     // graph branch execution is reverse, so former branch depends on latter branch
488     std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), prev_node, add_node};
489     auto depend = node->func_graph()->NewCNode(depend_inputs);
490     MS_EXCEPTION_IF_NULL(depend);
491     depend->set_abstract(prev_node->abstract()->Clone());
492     manager->SetEdge(node, 1, depend);
493   }
494 }
495 
CreateReplaceGraph(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,const CNodePtrPair & alltoall_pair,GpeaInfo * gpea_info)496 CNodePtr CreateReplaceGraph(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
497                             const CNodePtrPair &alltoall_pair, GpeaInfo *gpea_info) {
498   auto front_alltoall = alltoall_pair.first;
499   auto back_alltoall = alltoall_pair.second;
500   auto graph_input = front_alltoall->input(1);
501 
502   // split input into several branch
503   size_t front_split_dim = GetSplitDimFromAlltoall(front_alltoall);
504   size_t front_concat_dim = GetConcatDimFromAlltoall(front_alltoall);
505   size_t split_num = LongToSize(gpea_info->GetGroupNum());
506   auto split = NewSplitNode(graph_input, front_split_dim, split_num);
507 
508   // clone several branch calculation graph
509   size_t back_split_dim = GetSplitDimFromAlltoall(back_alltoall);
510   size_t back_concat_dim = GetConcatDimFromAlltoall(back_alltoall);
511   auto back_input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(back_alltoall, 0);
512   if (split_num == 0) {
513     MS_LOG(INTERNAL_EXCEPTION) << "split_num should not be zero.";
514   }
515   back_input_shape[back_split_dim] /= SizeToLong(split_num);
516   auto back_input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(back_alltoall, 0);
517   std::vector<int64_t> send_group_ranks = gpea_info->GetSendGroupRanks();
518   std::vector<std::vector<AnfNodePtr>> front_comm_branches;
519   std::vector<std::vector<AnfNodePtr>> back_comm_branches;
520   std::vector<AnfNodePtr> branch_output_nodes;
521   auto old_calc_cnodes = FindCNodesAmongAlltoall(origin_nodes_topological, alltoall_pair);
522   for (size_t i = 0; i < split_num; i++) {
523     auto send_rank_ids = gpea_info->GetSendRankIds(i);
524     auto recv_rank_ids = gpea_info->GetRecvRankIds(i);
525     size_t split_branch_idx = LongToSize(send_group_ranks[i]);
526     auto getitem = NewTupleGetItemNode(split, split_branch_idx);
527     // create first neighbor exchange nodes
528     std::vector<AnfNodePtr> front_neighbor_exchange_nodes;
529     CreateNeighborExchangeNodes(getitem, front_split_dim, front_concat_dim, send_rank_ids, recv_rank_ids,
530                                 &front_neighbor_exchange_nodes);
531     front_comm_branches.push_back(front_neighbor_exchange_nodes);
532     // clone calculation nodes
533     std::vector<AnfNodePtr> new_calc_nodes;
534     CloneScaledGraph(old_calc_cnodes, front_neighbor_exchange_nodes.back(), split_num, gpea_info, &new_calc_nodes);
535     MS_LOG(DEBUG) << "Create calculation done";
536     // create second neighbor exchange nodes
537     auto back_comm_split =
538       NewSplitNode(new_calc_nodes.back(), back_split_dim, recv_rank_ids.size(), back_input_shape, back_input_dtype);
539     std::vector<AnfNodePtr> back_neighbor_exchange_nodes;
540     back_neighbor_exchange_nodes.push_back(back_comm_split);
541     CreateNeighborExchangeNodes(back_comm_split, back_split_dim, back_concat_dim, recv_rank_ids, send_rank_ids,
542                                 &back_neighbor_exchange_nodes);
543     branch_output_nodes.push_back(back_neighbor_exchange_nodes.back());
544     back_comm_branches.push_back(back_neighbor_exchange_nodes);
545   }
546   MS_LOG(DEBUG) << "Create multi branch done";
547   InsertDependOnBranches(manager, front_comm_branches, back_comm_branches);
548   MS_LOG(DEBUG) << "InsertDependOnBranches done";
549 
550   // concat several branch into one branch
551   auto maketuple = NewMakeTupleNode(branch_output_nodes);
552   auto concat = NewConcatNode(maketuple, back_concat_dim, split_num);
553   auto reorder_split = NewSplitNode(concat, back_concat_dim, split_num);
554   std::vector<AnfNodePtr> reorder_getitem_nodes;
555   std::vector<int64_t> sort_idx = gpea_info->GetSortedInputsIdx();
556   MakeSortedSplitGetItemNodes(reorder_split, sort_idx, &reorder_getitem_nodes);
557   auto reorder_maketuple = NewMakeTupleNode(reorder_getitem_nodes);
558   auto reorder_concat = NewConcatNode(reorder_maketuple, back_concat_dim, split_num);
559   MS_LOG(DEBUG) << "Create replace graph done";
560   return reorder_concat;
561 }
562 
CreateAndReplaceAlltoall(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,const CNodePtrPair & alltoall_pair,GpeaInfo * gpea_info)563 void CreateAndReplaceAlltoall(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
564                               const CNodePtrPair &alltoall_pair, GpeaInfo *gpea_info) {
565   auto cnode = CreateReplaceGraph(manager, origin_nodes_topological, alltoall_pair, gpea_info);
566   (void)manager->Replace(alltoall_pair.second, cnode);
567 }
568 
CreateAndReplaceGraph(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,const std::vector<CNodePtrPair> & alltoall_pairs,GpeaInfo * gpea_info)569 void CreateAndReplaceGraph(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &origin_nodes_topological,
570                            const std::vector<CNodePtrPair> &alltoall_pairs, GpeaInfo *gpea_info) {
571   for (size_t i = 0; i < alltoall_pairs.size(); i++) {
572     CreateAndReplaceAlltoall(manager, origin_nodes_topological, alltoall_pairs[i], gpea_info);
573   }
574 }
575 
CheckReshapeScaleAxis(const std::vector<CNodePtr> & origin_nodes_topological,const CNodePtrPair & alltoall_pair,GpeaInfo * gpea_info)576 void CheckReshapeScaleAxis(const std::vector<CNodePtr> &origin_nodes_topological, const CNodePtrPair &alltoall_pair,
577                            GpeaInfo *gpea_info) {
578   size_t front_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, alltoall_pair.first);
579   size_t back_alltoall_idx = FindAlltoallIndex(origin_nodes_topological, alltoall_pair.second);
580   size_t reshape_cnode_num = 0;
581   for (size_t i = front_alltoall_idx + 1; i < back_alltoall_idx; i++) {
582     auto cnode = origin_nodes_topological[i];
583     if (IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
584       reshape_cnode_num += 1;
585     }
586   }
587 
588   std::vector<uint32_t> reshape_scale_axis = gpea_info->GetReshapeScaleAxisVec();
589   size_t axis_length = reshape_scale_axis.size();
590   MS_LOG(DEBUG) << "The graph has " << reshape_cnode_num << " reshape nodes, which will be scaled on certain axis";
591   if (axis_length == reshape_cnode_num) {
592     MS_LOG(DEBUG) << "'reshape_scale_axis' " << reshape_scale_axis;
593     return;
594   } else {
595     MS_LOG(DEBUG) << "'reshape_scale_axis' has " << axis_length
596                   << " element, its length should be same as the number of scaled reshape node";
597     MS_LOG(DEBUG) << "Show graph nodes start";
598     for (size_t i = front_alltoall_idx; i < back_alltoall_idx + 1; i++) {
599       auto cnode = origin_nodes_topological[i];
600       auto prim_name = GetCNodePrimitive(cnode)->name();
601       auto scope_name = cnode->scope()->name();
602       std::string input_shape_string = "";
603       std::string output_shape_string = "";
604       std::string space = " ";
605       size_t output_num = AnfUtils::GetOutputTensorNum(cnode);
606       size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
607       for (size_t j = 0; j < input_num; j++) {
608         auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, j);
609         if (shape.size() > 0) {
610           input_shape_string += space;
611           input_shape_string += parallel::ShapeToString(shape);
612         }
613       }
614       for (size_t j = 0; j < output_num; j++) {
615         auto shape = common::AnfAlgo::GetOutputInferShape(cnode, j);
616         if (shape.size() > 0) {
617           output_shape_string += space;
618           output_shape_string += parallel::ShapeToString(shape);
619         }
620       }
621 
622       MS_LOG(DEBUG) << "Node: " << prim_name;
623       MS_LOG(DEBUG) << "Scope: " << scope_name;
624       MS_LOG(DEBUG) << "Input shapes: " << input_shape_string;
625       MS_LOG(DEBUG) << "Output shapes: " << output_shape_string;
626     }
627     MS_LOG(DEBUG) << "Show graph nodes end";
628     MS_LOG(EXCEPTION) << "The size of 'reshape scale axis' is not equal to reshape nodes number in graph. There are "
629                       << reshape_cnode_num
630                       << " reshape nodes to be scaled. Please set correct scale axis for each reshape node";
631   }
632 }
633 
CheckUserSettings(const FuncGraphPtr & fg,GpeaInfo * gpea_info)634 bool CheckUserSettings(const FuncGraphPtr &fg, GpeaInfo *gpea_info) {
635   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel) {
636     MS_LOG(DEBUG) << "To activate the pass, set_auto_parallel_context 'parallel_mode' should be 'semi_auto_parallel'";
637     return false;
638   }
639 
640   if (!parallel::ParallelContext::GetInstance()->enable_all2all()) {
641     MS_LOG(DEBUG) << "To activate the pass, set_auto_parallel_context 'enable_alltoall' should be true";
642     return false;
643   }
644 
645   if (fg->has_flag(kTraining)) {
646     MS_LOG(DEBUG) << "To activate the pass, network 'set_train' should be false";
647     return false;
648   }
649 
650   gpea_info->DisplayInfo();
651 
652   int64_t gpea_num = gpea_info->GetGroupNum();
653   if (gpea_num <= 1 || LongToSize(gpea_num) == GetDeviceNum()) {
654     MS_LOG(DEBUG) << "To activate the pass, gpea_num " << gpea_num << " should between (1, " << GetDeviceNum() << ")";
655     return false;
656   }
657 
658   if (GetDeviceNum() % LongToSize(gpea_num) != 0) {
659     MS_LOG(DEBUG) << "To activate the pass, device num " << GetDeviceNum() << " should be divisible by gpea_num "
660                   << LongToSize(gpea_num);
661     return false;
662   }
663   return true;
664 }
665 }  // namespace
666 
GetDeviceNum()667 size_t GetDeviceNum() { return parallel::g_device_manager->DeviceNum(); }
668 
GetGlobalRankID()669 size_t GetGlobalRankID() { return LongToSize(parallel::g_device_manager->global_rank()); }
670 
SetGroupedPairwiseExchangeAllToAll(const pipeline::ResourcePtr & resource)671 void SetGroupedPairwiseExchangeAllToAll(const pipeline::ResourcePtr &resource) {
672   if (parallel::g_device_manager == nullptr) {
673     MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
674     return;
675   }
676 
677   MS_EXCEPTION_IF_NULL(resource);
678   FuncGraphPtr func_graph = resource->func_graph();
679   MS_EXCEPTION_IF_NULL(func_graph);
680 
681   auto gpea_info = GpeaInfo();
682   if (!CheckUserSettings(func_graph, &gpea_info)) {
683     return;
684   }
685 
686   auto manager = func_graph->manager();
687   MS_EXCEPTION_IF_NULL(manager);
688   std::list<CNodePtr> orders = func_graph->GetOrderedCnodes();
689   std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());
690 
691   std::vector<CNodePtrPair> alltoall_pairs;
692   FindAlltoallNodePairs(manager, origin_nodes_topological, &alltoall_pairs);
693   MS_LOG(DEBUG) << "Find alltoall_pairs num: " << alltoall_pairs.size();
694   if (alltoall_pairs.size() == 0) {
695     MS_LOG(WARNING) << "Not find alltoall_pairs, skip the pass";
696     return;
697   }
698 
699   CheckReshapeScaleAxis(origin_nodes_topological, alltoall_pairs[0], &gpea_info);
700 
701   CreateAndReplaceGraph(manager, origin_nodes_topological, alltoall_pairs, &gpea_info);
702   MS_LOG(DEBUG) << "CreateAndReplaceGraph done";
703 
704   // Renormalize, infer shape and set abstract for all nodes in graph
705   abstract::AbstractBasePtrList args_abs;
706   auto parameters = func_graph->parameters();
707   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
708                        [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
709   FuncGraphPtr new_fg = pipeline::Renormalize(resource, func_graph, args_abs);
710   resource->set_func_graph(new_fg);
711   resource->set_args_abs(args_abs);
712   MS_LOG(DEBUG) << "Renormalize done";
713   return;
714 }
715 }  // namespace opt
716 }  // namespace mindspore
717