/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/pass/merge_comm.h"
#include <algorithm>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include <memory>
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
namespace {
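// Extracts the target shape from a Reshape whose second input is a MakeTuple of scalars.
// Constant Int64 elements are copied as-is; non-constant (dynamic) elements are recorded as -1.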
static Shape GetMakeTupleValue(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_CHECK_FAIL(cnode->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  auto make_tuple = cnode->input(kIndex2);
  auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
  Shape ret;
  for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
    auto input_node = make_tuple_cnode->input(i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto value_node = GetValueNode(input_node);
    if (value_node != nullptr && value_node->isa<Int64Imm>()) {
      auto shape_ele = GetValue<int64_t>(value_node);
      ret.push_back(shape_ele);
    } else {
      ret.push_back(-1);
    }
  }
  return ret;
}

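// Compares the dynamic target shapes of two Reshape nodes. Both target shapes must come from MakeTuple
// inputs and contain at most one dynamic (-1) dimension; otherwise the nodes are treated as different.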
static bool IsSameTargetDynamicShape(const CNodePtr &reshape_node_a, const CNodePtr &reshape_node_b) {
  MS_EXCEPTION_IF_NULL(reshape_node_a);
  MS_EXCEPTION_IF_NULL(reshape_node_b);
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_a->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_b->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  if (!IsPrimitiveCNode(reshape_node_a->input(kIndex2), prim::kPrimMakeTuple)) {
    MS_LOG(WARNING) << "The dst shape of reshape node a is not a MakeTuple, cannot compare dynamic shapes.";
    return false;
  }

  if (!IsPrimitiveCNode(reshape_node_b->input(kIndex2), prim::kPrimMakeTuple)) {
    MS_LOG(WARNING) << "The dst shape of reshape node b is not a MakeTuple, cannot compare dynamic shapes.";
    return false;
  }

  Shape node_a_shape = GetMakeTupleValue(reshape_node_a);
  Shape node_b_shape = GetMakeTupleValue(reshape_node_b);
  MS_LOG(INFO) << "The dst shape of node a is " << node_a_shape << ", the dst shape of node b is " << node_b_shape;
  if (std::count(node_a_shape.cbegin(), node_a_shape.cend(), -1) > 1) {
    return false;
  }
  if (std::count(node_b_shape.cbegin(), node_b_shape.cend(), -1) > 1) {
    return false;
  }

  return (node_a_shape == node_b_shape);
}

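// Returns true if two Reshape nodes have identical target shapes. Falls back to the dynamic-shape
// comparison when either target shape is not a constant ValueNode.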
bool IsSameTargetShape(const CNodePtr &reshape_node_a, const CNodePtr &reshape_node_b) {
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_a->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_b->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  if (!reshape_node_a->input(kIndex2)->isa<ValueNode>() || !reshape_node_b->input(kIndex2)->isa<ValueNode>()) {
    return IsSameTargetDynamicShape(reshape_node_a, reshape_node_b);
  }
  auto value_ptr_a = reshape_node_a->input(kIndex2)->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  auto value_ptr_b = reshape_node_b->input(kIndex2)->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  if (value_ptr_a.size() != value_ptr_b.size()) {
    return false;
  }
  for (size_t i = 0; i < value_ptr_a.size(); i++) {
    int64_t cur_shape_a = GetValue<int64_t>(value_ptr_a.at(i));
    int64_t cur_shape_b = GetValue<int64_t>(value_ptr_b.at(i));
    if (cur_shape_a != cur_shape_b) {
      return false;
    }
  }
  return true;
}

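// Groups AllGather nodes by the node that produces their input (looking through Reshape), and replaces
// each group of equivalent AllGather nodes with a single one to remove redundant communication.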
void MergeAllGather(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
  std::unordered_map<CNodePtr, std::vector<CNodePtr>> allgather_input_map;
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimAllGather)) {
      continue;
    }
    auto allgather_cnode = node->cast<CNodePtr>();
    auto pre_node = GetInputNodeWithFilter(allgather_cnode->input(kIndex1), [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimReshape);
      return std::make_pair(filter, 1);
    });
    if (!IsPrimitiveCNode(pre_node)) {
      continue;
    }
    auto pre_cnode = pre_node->cast<CNodePtr>();
    allgather_input_map[pre_cnode].push_back(allgather_cnode);
  }
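  // For each producer node, check whether its AllGather consumers can be merged into one.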
  for (const auto &allgather_pairs : allgather_input_map) {
    if (allgather_pairs.second.size() <= 1) {
      continue;
    }
    auto allgather_list = allgather_pairs.second;
    auto allgather_cnode1 = allgather_list.front();
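    // All AllGather nodes in the group must use the same communication group, have matching Reshape
    // structure (and target shape) on their input, and belong to the same func_graph.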
    auto is_same_allgather =
      std::all_of(allgather_list.begin(), allgather_list.end(), [&allgather_cnode1](const CNodePtr &allgather_cnode2) {
        auto ag1_prim = GetCNodePrimitive(allgather_cnode1);
        auto ag2_prim = GetCNodePrimitive(allgather_cnode2);
        auto group1 = ag1_prim->GetAttr(GROUP);
        auto group2 = ag2_prim->GetAttr(GROUP);
        if (!group1 || !group2) {
          return false;
        }
        if (GetValue<std::string>(group1) != GetValue<std::string>(group2)) {
          return false;
        }
        if (IsPrimitiveCNode(allgather_cnode1->input(kIndex1), prim::kPrimReshape) !=
            IsPrimitiveCNode(allgather_cnode2->input(kIndex1), prim::kPrimReshape)) {
          return false;
        }
        if (IsPrimitiveCNode(allgather_cnode1->input(kIndex1), prim::kPrimReshape) &&
            IsPrimitiveCNode(allgather_cnode2->input(kIndex1), prim::kPrimReshape)) {
          if (!IsSameTargetShape(allgather_cnode1->input(kIndex1)->cast<CNodePtr>(),
                                 allgather_cnode2->input(kIndex1)->cast<CNodePtr>())) {
            return false;
          }
        }
        if (allgather_cnode1->func_graph() != allgather_cnode2->func_graph()) {
          return false;
        }
        return true;
      });
    if (!is_same_allgather) {
      MS_LOG(INFO) << "The AllGather nodes sharing the input node " << allgather_pairs.first->DebugString()
                   << " are not equivalent, skip merging them.";
      continue;
    }
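    // The nodes are equivalent: keep the first AllGather and redirect the users of the others to it.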
    auto ag0 = allgather_list.front();
    for (const auto &ag : allgather_list) {
      manager->Replace(ag, ag0);
    }
  }
}
}  // namespace

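// Pass entry: merges redundant AllGather nodes in the forward graph. Runs at most once per graph
// (guarded by the MERGE_COMM_RUN_ONCE_ONLY flag) and dumps the result as "merge_comm" when graph dumping is enabled.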
bool MergeComm(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  auto graph_set = ForwardGraph(root);
  // assume no change to the graph
  bool changes = false;
  // skip graphs that auto parallel does not handle, graphs already processed, and graphs without a forward part
  if (!IsAutoParallelCareGraph(root) || (root->has_flag(MERGE_COMM_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
    return changes;
  }
  FuncGraphManagerPtr manager;
  pipeline::ResourceBasePtr res;
  if (optimizer == nullptr) {
    manager = root->manager();
    res = std::make_shared<pipeline::Resource>();
    res->set_manager(manager);
  } else {
    res = optimizer->resource();
    MS_EXCEPTION_IF_NULL(res);
    manager = res->manager();
  }

  MS_EXCEPTION_IF_NULL(manager);
  CNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  MergeAllGather(all_nodes, manager);
  DumpGraph(root, std::string("merge_comm"));

  // the merge comm pass only runs once
  root->set_flag(MERGE_COMM_RUN_ONCE_ONLY, true);
  // Keep all func graphs for parallel before saving the result.
  SetReserved(root);
  res->SetResult(pipeline::kStepParallelGraph, root);
  return changes;
}
}  // namespace parallel
}  // namespace mindspore