1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pass/optimize_parallel_allgather_comm.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <list>
22 #include <unordered_map>
23 #include <algorithm>
24 #include "mindspore/core/ops/other_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "frontend/optimizer/optimizer.h"
27 #include "frontend/parallel/step_parallel_utils.h"
28 #include "frontend/parallel/graph_util/graph_info.h"
29
30 namespace mindspore {
31 namespace parallel {
32 namespace {
33
IsDTypeBitsDecrease(TypeId a,TypeId b)34 bool IsDTypeBitsDecrease(TypeId a, TypeId b) {
35 return a == kNumberTypeFloat32 && (b == kNumberTypeFloat16 || b == kNumberTypeBFloat16);
36 }
37
MoveCastBehindAllGather(const FuncGraphPtr & func_graph,const CNodePtr & all_gather_cnode,const CNodePtr & cast_cnode)38 void MoveCastBehindAllGather(const FuncGraphPtr &func_graph, const CNodePtr &all_gather_cnode,
39 const CNodePtr &cast_cnode) {
40 MS_EXCEPTION_IF_NULL(func_graph);
41 MS_EXCEPTION_IF_NULL(all_gather_cnode);
42 MS_EXCEPTION_IF_NULL(cast_cnode);
43 auto all_gather_dtype = common::AnfAlgo::GetOutputInferDataType(all_gather_cnode, kIndex0);
44 auto cast_dtype = common::AnfAlgo::GetOutputInferDataType(cast_cnode, kIndex0);
45 if (!IsDTypeBitsDecrease(all_gather_dtype, cast_dtype)) {
46 return;
47 }
48
49 auto manager = func_graph->manager();
50 MS_EXCEPTION_IF_NULL(manager);
51 auto cast_input_node = cast_cnode->input(kIndex1);
52 auto cast_input_node_users = GetOutputNodesWithFilter(cast_input_node, [](const AnfNodePtr &node) {
53 return IsOneOfPrimitiveCNode(node, {prim::kPrimMakeTuple, prim::kPrimDepend});
54 });
55 for (auto &cast_input_node_user_pair : cast_input_node_users) {
56 if (cast_input_node_user_pair.first != cast_cnode &&
57 !IsPrimitiveCNode(cast_input_node_user_pair.first, prim::kPrimUpdateState)) {
58 return;
59 }
60 }
61
62 // Get operator list from all_gather to cast
63 AnfNodePtrList op_list;
64 auto cur_node = cast_input_node;
65 while (cur_node != all_gather_cnode) {
66 op_list.push_back(cur_node);
67 auto cur_cnode = cur_node->cast<CNodePtr>();
68 if (cur_cnode == nullptr) {
69 break;
70 }
71 cur_node = cur_cnode->input(kIndex1);
72 }
73 if (cur_node != all_gather_cnode) {
74 MS_LOG(DEBUG) << "Get op list from all_gather to cast failed.";
75 return;
76 }
77 op_list.push_back(cur_node);
78
79 auto cast_node_users = manager->node_users()[cast_cnode];
80
81 for (const auto &cast_next_node_user_pair : cast_node_users) {
82 auto next_cnode = cast_next_node_user_pair.first->cast<CNodePtr>();
83 MS_EXCEPTION_IF_NULL(next_cnode);
84 auto next_index = cast_next_node_user_pair.second;
85 manager->SetEdge(next_cnode, next_index, cast_input_node);
86 }
87
88 auto all_gather_input_node = all_gather_cnode->input(kIndex1);
89 manager->SetEdge(cast_cnode, kIndex1, all_gather_input_node);
90 manager->SetEdge(all_gather_cnode, kIndex1, cast_cnode);
91
92 // Update abstract from cast to all_gather
93 auto new_cast_abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(cast_dtype),
94 cast_cnode->input(kIndex1)->abstract()->GetShape());
95 cast_cnode->set_abstract(new_cast_abs);
96 for (auto node : op_list) {
97 auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(cast_dtype), node->abstract()->GetShape());
98 node->set_abstract(abs);
99 }
100 return;
101 }
102 } // namespace
103
OptimizeParallelAllGatherComm(const FuncGraphPtr & graph)104 void OptimizeParallelAllGatherComm(const FuncGraphPtr &graph) {
105 auto manager = graph->manager();
106 for (const auto &each_graph : manager->func_graphs()) {
107 std::list<CNodePtr> graph_orders = each_graph->GetOrderedCnodes();
108 std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
109 for (const auto &node : origin_nodes_topological) {
110 if (!IsPrimitiveCNode(node, prim::kPrimAllGather) || !common::AnfAlgo::IsFromParallelOptimizer(node)) {
111 continue;
112 }
113 auto all_gather_cnode = node->cast<CNodePtr>();
114 auto all_gather_node_user_list = GetOutputNodesWithFilter(all_gather_cnode, [](const AnfNodePtr &node) {
115 return IsOneOfPrimitiveCNode(node, {prim::kPrimLoad, prim::kPrimDepend});
116 });
117 for (auto next_node_pair : all_gather_node_user_list) {
118 if (IsPrimitiveCNode(next_node_pair.first, prim::kPrimCast)) {
119 MoveCastBehindAllGather(each_graph, all_gather_cnode, next_node_pair.first->cast<CNodePtr>());
120 }
121 }
122 }
123 }
124 }
125 } // namespace parallel
126 } // namespace mindspore
127