/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
#include <memory>
#include <vector>
#include <list>
#include <algorithm>
#include <string>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/comm_manager.h"

namespace mindspore {
namespace parallel {
namespace {
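// Returns true if the node is an AllGather-style cnode (MicroStepAllGather, MiniStepAllGather or
// AllGather) whose instance name marks it as inserted by the parallel optimizer.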
inline bool is_allgather_comm_ops(const AnfNodePtr &node) {
  static const std::vector<PrimitivePtr> kAllGatherOpsPrim = {prim::kPrimMicroStepAllGather,
                                                              prim::kPrimMiniStepAllGather, prim::kPrimAllGather};

  for (const auto &prim : kAllGatherOpsPrim) {
    if (IsPrimitiveCNode(node, prim)) {
      auto allgather_instance_name = GetCNodePrimitive(node->cast<CNodePtr>())->instance_name();
      if (allgather_instance_name.find(parallel::PARALLEL_OPTIMIZER) == std::string::npos) {
        return false;
      }
      return true;
    }
  }
  return false;
}

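// Returns true for a forward pipeline Receive of micro batch 0 that is not a pipeline-parameter
// Receive, i.e. the Receive later tagged as parallel::FIRST_RECEIVE.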
inline bool is_first_receive(const AnfNodePtr &node) {
  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
    auto recv_node = node->cast<CNodePtr>();
    if (recv_node->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
      return false;
    }
    auto micro = GetValue<int64_t>(recv_node->GetPrimalAttr(parallel::MICRO));
    if (micro != 0 || recv_node->HasPrimalAttr(parallel::PIPELINE_PARAM)) {
      return false;
    }
    return true;
  }
  return false;
}
}  // namespace

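// Moves the optimizer-shard (parallel optimizer) AllGather operators into dedicated
// "*_parallel_optimizer" communication groups and ties them to the first pipeline Receive with a
// Depend edge, so that the optimizer-shard AllGather communication can be overlapped with pipeline
// communication. The pass is skipped on the GE backend, in evaluation, without pipeline stages, or
// when gradient-accumulation sharding or fold-pipeline is enabled.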
void OverlapOptShardInPipeline(const FuncGraphPtr &graph) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  static const bool is_enable_ge = (context->backend_policy() == "ge");
  if (is_enable_ge) {
    return;
  }
  if (parallel::g_device_manager == nullptr) {
    MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  if (!IsTraining(graph->manager())) {
    MS_LOG(INFO) << "Skip overlap in Evaluation.";
    return;
  }
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (!parallel::IsAutoParallelCareGraph(graph) ||
      parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1 ||
      parallel::ParallelContext::GetInstance()->grad_accumulation_shard()) {
    return;
  }
  if (parallel::ParallelContext::GetInstance()->enable_fold_pipeline()) {
    return;
  }
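  // Walk the graph in execution order, find the forward Receive of micro batch 0 and tag it.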
  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());
  CNodePtr first_receive_cnode = nullptr;
  for (auto &node : origin_nodes_topological) {
    if (is_first_receive(node)) {
      first_receive_cnode = node->cast<CNodePtr>();
      first_receive_cnode->AddAttr(parallel::FIRST_RECEIVE, MakeValue(true));
    }
  }
  if (first_receive_cnode == nullptr) {
    return;
  }
  auto recv_users = manager->node_users()[first_receive_cnode];
  if (recv_users.empty()) {
    return;
  }

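  // Collect every optimizer-shard AllGather and, unless it already belongs to a parallel-optimizer
  // group, recreate its communication group under the name "<group>_parallel_optimizer" with the
  // same ranks and rewrite the AllGather's GROUP attribute accordingly.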
  std::vector<CNodePtr> opt_shard_allgather_list;
  for (auto &node : origin_nodes_topological) {
    MS_EXCEPTION_IF_NULL(node);
    if (!is_allgather_comm_ops(node)) {
      continue;
    }
    auto cnode_allgather = node->cast<CNodePtr>();
    opt_shard_allgather_list.push_back(cnode_allgather);
    auto allgather_prim = GetCNodePrimitive(cnode_allgather);
    auto group_name = GetValue<std::string>(allgather_prim->GetAttr(parallel::GROUP));
    if (group_name.find("parallel_optimizer") != std::string::npos) {
      continue;
    }
    auto rank_ids = parallel::g_device_manager->FindRankListByHashName(group_name);
    if (rank_ids.empty()) {
      continue;
    }
    auto dev_list = parallel::g_device_manager->CreateDeviceListByRankList(rank_ids);
    auto new_group_name = group_name + "_parallel_optimizer";
    parallel::Group cur_device_list;
    (void)parallel::g_device_manager->CreateGroup(new_group_name, dev_list, &cur_device_list);
    (void)allgather_prim->AddAttr(parallel::GROUP, MakeValue<std::string>(new_group_name));
  }
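  // Bundle the collected AllGathers into a MakeTuple and hang it on the first Receive through a
  // Depend node; replacing the Receive with the Depend orders the optimizer-shard AllGathers
  // relative to the first pipeline Receive's consumers.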
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  (void)std::copy(opt_shard_allgather_list.begin(), opt_shard_allgather_list.end(),
                  std::back_inserter(make_tuple_inputs));
  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_receive_cnode,
                                        graph->NewCNode(make_tuple_inputs)};
  auto depend_node = graph->NewCNode(depend_inputs);
  depend_node->set_abstract(first_receive_cnode->abstract()->Clone());
  depend_node->AddAttr("RecAllGatherDepend", MakeValue(true));
  (void)manager->Replace(first_receive_cnode, depend_node);
}

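// Collects the ReduceScatter cnodes produced for optimizer sharding, identified by the
// kAttrNeedAllGather marker in their instance name.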
static std::vector<CNodePtr> GetOptShardReduceScatter(const std::vector<AnfNodePtr> &all_nodes) {
  std::vector<CNodePtr> reduce_scatters;
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimReduceScatter)) {
      continue;
    }
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    auto instance_name = prim->instance_name();
    if (instance_name.find(kAttrNeedAllGather) != std::string::npos) {
      (void)reduce_scatters.emplace_back(node->cast<CNodePtr>());
    }
  }
  return reduce_scatters;
}

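// In (semi-)auto-parallel kernel-by-kernel execution with more than one pipeline stage, inserts
// Depend nodes so that each optimizer-shard ReduceScatter is ordered after the backward Send of the
// last micro batch, letting gradient communication overlap with the tail of pipeline communication.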
void OverlapOptShardGradInPipeline(const FuncGraphPtr &graph) {
  if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
      parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
    return;
  }
  if (parallel::g_device_manager == nullptr) {
    MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
    return;
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto is_kbk_mode = context->IsKByKExecutorMode();
  if (!is_kbk_mode) {
    return;
  }
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  if (stage_num <= 1) {
    return;
  }
  auto ret_after = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret_after);
  auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
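  // Locate the backward Send (a Send carrying kPrimalAttrForwardNodeName) with the largest micro index.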
  std::vector<CNodePtr> sends;
  int64_t micro_size = 1;
  CNodePtr last_send = nullptr;
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
      continue;
    }
    sends.emplace_back(cnode);
    auto micro_attr = cnode->GetPrimalAttr(MICRO);
    MS_EXCEPTION_IF_NULL(micro_attr);
    auto micro = GetValue<int64_t>(micro_attr) + 1;
    if (micro > micro_size) {
      micro_size = micro;
      last_send = cnode;
    }
  }
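  // Insert a Depend on each optimizer-shard ReduceScatter's input so it is ordered after the last
  // Send, and mark the Depend with PP_OPT_SHARD_CONTROL.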
  if (last_send != nullptr) {
    auto opt_shard_rs = GetOptShardReduceScatter(all_nodes);
    if (opt_shard_rs.empty()) {
      return;
    }
    auto manager = graph->manager();
    for (const auto &rs : opt_shard_rs) {
      std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), rs->input(1), last_send};
      auto depend = graph->NewCNode(depend_input);
      depend->AddPrimalAttr(PP_OPT_SHARD_CONTROL, MakeValue(1));
      depend->set_abstract(rs->input(1)->abstract());
      manager->SetEdge(rs, 1, depend);
    }
  }
}
}  // namespace parallel
}  // namespace mindspore