• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
#include <memory>
#include <vector>
#include <list>
#include <algorithm>
#include <string>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/comm_manager.h"

namespace mindspore {
namespace parallel {
namespace {
is_allgather_comm_ops(const AnfNodePtr & node)36 inline bool is_allgather_comm_ops(const AnfNodePtr &node) {
37   static const std::vector<PrimitivePtr> kAllGatherOpsPrim = {prim::kPrimMicroStepAllGather,
38                                                               prim::kPrimMiniStepAllGather, prim::kPrimAllGather};
39 
40   for (const auto &prim : kAllGatherOpsPrim) {
41     if (IsPrimitiveCNode(node, prim)) {
42       auto allgather_instance_name = GetCNodePrimitive(node->cast<CNodePtr>())->instance_name();
43       if (allgather_instance_name.find(parallel::PARALLEL_OPTIMIZER) == std::string::npos) {
44         return false;
45       }
46       return true;
47     }
48   }
49   return false;
50 }
is_first_receive(const AnfNodePtr & node)52 inline bool is_first_receive(const AnfNodePtr &node) {
53   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
54     auto recv_node = node->cast<CNodePtr>();
55     if (recv_node->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
56       return false;
57     }
58     auto micro = GetValue<int64_t>(recv_node->GetPrimalAttr(parallel::MICRO));
59     if (micro != 0 || recv_node->HasPrimalAttr(parallel::PIPELINE_PARAM)) {
60       return false;
61     }
62     return true;
63   }
64   return false;
65 }
}  // namespace

OverlapOptShardInPipeline(const FuncGraphPtr & graph)68 void OverlapOptShardInPipeline(const FuncGraphPtr &graph) {
69   auto context = MsContext::GetInstance();
70   MS_EXCEPTION_IF_NULL(context);
71   static const bool is_enable_ge = (context->backend_policy() == "ge");
72   if (is_enable_ge) {
73     return;
74   }
75   if (parallel::g_device_manager == nullptr) {
76     MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
77     return;
78   }
79   MS_EXCEPTION_IF_NULL(graph);
80   if (!IsTraining(graph->manager())) {
81     MS_LOG(INFO) << "Skip overlap in Evaluation.";
82     return;
83   }
84   auto manager = graph->manager();
85   MS_EXCEPTION_IF_NULL(manager);
86   if (!parallel::IsAutoParallelCareGraph(graph) ||
87       parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1 ||
88       parallel::ParallelContext::GetInstance()->grad_accumulation_shard()) {
89     return;
90   }
91   if (parallel::ParallelContext::GetInstance()->enable_fold_pipeline()) {
92     return;
93   }
94   std::list<CNodePtr> orders = graph->GetOrderedCnodes();
95   std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());
96   CNodePtr first_receive_cnode = nullptr;
97   for (auto &node : origin_nodes_topological) {
98     if (is_first_receive((node))) {
99       first_receive_cnode = node->cast<CNodePtr>();
100       first_receive_cnode->AddAttr(parallel::FIRST_RECEIVE, MakeValue(True));
101     }
102   }
103   if (first_receive_cnode == nullptr) {
104     return;
105   }
106   auto recv_users = manager->node_users()[first_receive_cnode];
107   if (recv_users.empty()) {
108     return;
109   }
110 
111   std::vector<CNodePtr> opt_shard_allgather_list;
112   for (auto &node : origin_nodes_topological) {
113     MS_EXCEPTION_IF_NULL(node);
114     if (!is_allgather_comm_ops(node)) {
115       continue;
116     }
117     auto cnode_allgather = node->cast<CNodePtr>();
118     opt_shard_allgather_list.push_back(cnode_allgather);
119     auto allgather_prim = GetCNodePrimitive(cnode_allgather);
120     auto group_name = GetValue<std::string>(allgather_prim->GetAttr(parallel::GROUP));
121     if (group_name.find("parallel_optimizer") != std::string::npos) {
122       continue;
123     }
124     auto rank_ids = parallel::g_device_manager->FindRankListByHashName(group_name);
125     if (rank_ids.empty()) {
126       continue;
127     }
128     auto dev_list = parallel::g_device_manager->CreateDeviceListByRankList(rank_ids);
129     auto new_group_name = group_name + "_parallel_optimizer";
130     parallel::Group cur_device_list;
131     (void)parallel::g_device_manager->CreateGroup(new_group_name, dev_list, &cur_device_list);
132     (void)allgather_prim->AddAttr(parallel::GROUP, MakeValue<std::string>(new_group_name));
133   }
134   std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
135   (void)std::copy(opt_shard_allgather_list.begin(), opt_shard_allgather_list.end(),
136                   std::back_inserter(make_tuple_inputs));
137   std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_receive_cnode,
138                                         graph->NewCNode(make_tuple_inputs)};
139   auto depend_node = graph->NewCNode(depend_inputs);
140   depend_node->set_abstract(first_receive_cnode->abstract()->Clone());
141   depend_node->AddAttr("RecAllGatherDepend", MakeValue(True));
142   (void)manager->Replace(first_receive_cnode, depend_node);
143 }
GetOptShardReduceScatter(const std::vector<AnfNodePtr> & all_nodes)145 static std::vector<CNodePtr> GetOptShardReduceScatter(const std::vector<AnfNodePtr> &all_nodes) {
146   std::vector<CNodePtr> reduce_scatters;
147   for (const auto &node : all_nodes) {
148     if (!IsPrimitiveCNode(node, prim::kPrimReduceScatter)) {
149       continue;
150     }
151     auto prim = GetCNodePrimitive(node);
152     MS_EXCEPTION_IF_NULL(prim);
153     auto instance_name = prim->instance_name();
154     if (instance_name.find(kAttrNeedAllGather) != std::string::npos) {
155       (void)reduce_scatters.emplace_back(node->cast<CNodePtr>());
156     }
157   }
158   return reduce_scatters;
159 }
OverlapOptShardGradInPipeline(const FuncGraphPtr & graph)161 void OverlapOptShardGradInPipeline(const FuncGraphPtr &graph) {
162   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
163       parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
164     return;
165   }
166   if (parallel::g_device_manager == nullptr) {
167     MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
168     return;
169   }
170   auto context = MsContext::GetInstance();
171   MS_EXCEPTION_IF_NULL(context);
172   auto is_kbk_mode = context->IsKByKExecutorMode();
173   if (!is_kbk_mode) {
174     return;
175   }
176   MS_EXCEPTION_IF_NULL(g_device_manager);
177   auto stage_num = g_device_manager->stage_num();
178   if (stage_num <= 1) {
179     return;
180   }
181   auto ret_after = graph->get_return();
182   MS_EXCEPTION_IF_NULL(ret_after);
183   auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
184   std::vector<CNodePtr> sends;
185   int64_t micro_size = 1;
186   CNodePtr last_send = nullptr;
187   for (const auto &node : all_nodes) {
188     if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
189       continue;
190     }
191     auto cnode = node->cast<CNodePtr>();
192     if (!cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
193       continue;
194     }
195     sends.emplace_back(cnode);
196     auto micro_attr = cnode->GetPrimalAttr(MICRO);
197     MS_EXCEPTION_IF_NULL(micro_attr);
198     auto micro = GetValue<int64_t>(micro_attr) + 1;
199     if (micro > micro_size) {
200       micro_size = micro;
201       last_send = cnode;
202     }
203   }
204   if (last_send != nullptr) {
205     auto opt_shard_rs = GetOptShardReduceScatter(all_nodes);
206     if (opt_shard_rs.empty()) {
207       return;
208     }
209     auto manager = graph->manager();
210     for (auto rs : opt_shard_rs) {
211       std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), rs->input(1), last_send};
212       auto depend = graph->NewCNode(depend_input);
213       depend->AddPrimalAttr(PP_OPT_SHARD_CONTROL, MakeValue(1));
214       depend->set_abstract(rs->input(1)->abstract());
215       manager->SetEdge(rs, 1, depend);
216     }
217   }
218 }
}  // namespace parallel
}  // namespace mindspore