• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pass/optimize_parallel_allgather_comm.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <list>
22 #include <unordered_map>
23 #include <algorithm>
24 #include "mindspore/core/ops/other_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "frontend/optimizer/optimizer.h"
27 #include "frontend/parallel/step_parallel_utils.h"
28 #include "frontend/parallel/graph_util/graph_info.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 namespace {
33 
IsDTypeBitsDecrease(TypeId a,TypeId b)34 bool IsDTypeBitsDecrease(TypeId a, TypeId b) {
35   return a == kNumberTypeFloat32 && (b == kNumberTypeFloat16 || b == kNumberTypeBFloat16);
36 }
37 
// Rewrites the pattern `AllGather(fp32) -> ... -> Cast(->fp16/bf16)` into
// `Cast -> AllGather -> ...` so the collective communicates the narrower dtype.
// No-op unless the cast strictly decreases bit width and the rewrite is safe
// (the cast's input has no other consumers besides UpdateState).
void MoveCastBehindAllGather(const FuncGraphPtr &func_graph, const CNodePtr &all_gather_cnode,
                             const CNodePtr &cast_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_gather_cnode);
  MS_EXCEPTION_IF_NULL(cast_cnode);
  auto all_gather_dtype = common::AnfAlgo::GetOutputInferDataType(all_gather_cnode, kIndex0);
  auto cast_dtype = common::AnfAlgo::GetOutputInferDataType(cast_cnode, kIndex0);
  // Only move the cast when it narrows fp32 to fp16/bf16; anything else is left untouched.
  if (!IsDTypeBitsDecrease(all_gather_dtype, cast_dtype)) {
    return;
  }

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto cast_input_node = cast_cnode->input(kIndex1);
  // Collect consumers of the cast's input, looking through MakeTuple/Depend wrappers.
  auto cast_input_node_users = GetOutputNodesWithFilter(cast_input_node, [](const AnfNodePtr &node) {
    return IsOneOfPrimitiveCNode(node, {prim::kPrimMakeTuple, prim::kPrimDepend});
  });
  // Bail out if the cast's input feeds any node other than the cast itself (or
  // an UpdateState): those consumers would observe the dtype change after the move.
  for (auto &cast_input_node_user_pair : cast_input_node_users) {
    if (cast_input_node_user_pair.first != cast_cnode &&
        !IsPrimitiveCNode(cast_input_node_user_pair.first, prim::kPrimUpdateState)) {
      return;
    }
  }

  // Get operator list from all_gather to cast
  // Walk backwards from the cast's input along first inputs until we reach the
  // AllGather, recording every node in between; these nodes' dtypes must be
  // rewritten after the move.
  AnfNodePtrList op_list;
  auto cur_node = cast_input_node;
  while (cur_node != all_gather_cnode) {
    op_list.push_back(cur_node);
    auto cur_cnode = cur_node->cast<CNodePtr>();
    if (cur_cnode == nullptr) {
      break;
    }
    cur_node = cur_cnode->input(kIndex1);
  }
  // If the walk ended without reaching the AllGather (chain broken by a
  // non-CNode), the pattern does not match — abort without modifying the graph.
  if (cur_node != all_gather_cnode) {
    MS_LOG(DEBUG) << "Get op list from all_gather to cast failed.";
    return;
  }
  op_list.push_back(cur_node);

  // Snapshot the cast's users before rewiring (node_users() is mutated by SetEdge below).
  auto cast_node_users = manager->node_users()[cast_cnode];

  // Splice the cast out of its current position: its consumers now read
  // directly from the cast's former input.
  for (const auto &cast_next_node_user_pair : cast_node_users) {
    auto next_cnode = cast_next_node_user_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(next_cnode);
    auto next_index = cast_next_node_user_pair.second;
    manager->SetEdge(next_cnode, next_index, cast_input_node);
  }

  // Re-insert the cast between the AllGather and its input:
  // before: all_gather_input -> AllGather;  after: all_gather_input -> Cast -> AllGather.
  auto all_gather_input_node = all_gather_cnode->input(kIndex1);
  manager->SetEdge(cast_cnode, kIndex1, all_gather_input_node);
  manager->SetEdge(all_gather_cnode, kIndex1, cast_cnode);

  // Update abstract from cast to all_gather
  // The cast keeps its (pre-AllGather) input shape but produces the narrow dtype;
  // every node between AllGather and the old cast position now carries the narrow dtype too.
  auto new_cast_abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(cast_dtype),
                                                                 cast_cnode->input(kIndex1)->abstract()->GetShape());
  cast_cnode->set_abstract(new_cast_abs);
  for (auto node : op_list) {
    auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(cast_dtype), node->abstract()->GetShape());
    node->set_abstract(abs);
  }
  return;
}
102 }  // namespace
103 
OptimizeParallelAllGatherComm(const FuncGraphPtr & graph)104 void OptimizeParallelAllGatherComm(const FuncGraphPtr &graph) {
105   auto manager = graph->manager();
106   for (const auto &each_graph : manager->func_graphs()) {
107     std::list<CNodePtr> graph_orders = each_graph->GetOrderedCnodes();
108     std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
109     for (const auto &node : origin_nodes_topological) {
110       if (!IsPrimitiveCNode(node, prim::kPrimAllGather) || !common::AnfAlgo::IsFromParallelOptimizer(node)) {
111         continue;
112       }
113       auto all_gather_cnode = node->cast<CNodePtr>();
114       auto all_gather_node_user_list = GetOutputNodesWithFilter(all_gather_cnode, [](const AnfNodePtr &node) {
115         return IsOneOfPrimitiveCNode(node, {prim::kPrimLoad, prim::kPrimDepend});
116       });
117       for (auto next_node_pair : all_gather_node_user_list) {
118         if (IsPrimitiveCNode(next_node_pair.first, prim::kPrimCast)) {
119           MoveCastBehindAllGather(each_graph, all_gather_cnode, next_node_pair.first->cast<CNodePtr>());
120         }
121       }
122     }
123   }
124 }
125 }  // namespace parallel
126 }  // namespace mindspore
127