/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ir/primal_attr.h"
#include "pipeline/jit/ps/pipeline_split.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/ms_context.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
#include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/parameter_manager.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

namespace mindspore {
namespace pipeline {
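// Returns true if any node in all_nodes is already a VirtualDataset CNode.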
bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
      return true;
    }
  }
  return false;
}

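// Builds a TupleGetItem(node, index) CNode, propagating the scope and the
// element abstract of the input tuple so that later passes see a typed output.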
static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto idx = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  tuple_get_item->set_scope(node->scope());
  auto input_abstract_tuple = node->abstract()->cast_ptr<abstract::AbstractTuple>();
  MS_EXCEPTION_IF_NULL(input_abstract_tuple);
  auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
  MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
  tuple_get_item->set_abstract(tuple_get_item_abstract);
  return tuple_get_item;
}

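// Creates a VirtualDataset CNode that takes all non-monad graph inputs and
// yields them as a tuple; its abstract mirrors the abstracts of those inputs.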
static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
  std::vector<AbstractBasePtr> abstract_list;
  std::vector<AnfNodePtr> virtual_dataset_node_inputs;
  for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
    auto graph_input = func_graph->get_inputs()[index];
    if (!HasAbstractMonad(graph_input)) {
      auto virtual_dataset_abstract = graph_input->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
      (void)abstract_list.emplace_back(virtual_dataset_abstract);
      virtual_dataset_node_inputs.push_back(graph_input);
    }
  }

  auto virtual_dataset_node = mindspore::parallel::CreateCNodeByInputsAndAttr(
    func_graph, mindspore::parallel::VIRTUAL_DATA_SET, mindspore::parallel::VIRTUAL_DATA_SET,
    virtual_dataset_node_inputs, {});
  MS_EXCEPTION_IF_NULL(virtual_dataset_node);
  virtual_dataset_node->set_in_forward_flag(true);
  virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return virtual_dataset_node;
}

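// Collects the forward graphs that consume the root graph's weight-free input
// parameters, plus graphs wrapped by J (grad) or, in PyNative mode, by Shard.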
static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  std::set<FuncGraphPtr> graph_sets;
  auto manager = root->manager();
  if (!parallel::IsAutoParallelCareGraph(root)) {
    return graph_sets;
  }
  std::set<AnfNodePtr> input_parameters;
  for (auto &anf_param : root->parameters()) {
    auto param = anf_param->cast_ptr<Parameter>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      (void)input_parameters.insert(anf_param);
    }
  }
  auto node_users_map = manager->node_users();
  for (const auto &input_parameter : input_parameters) {
    auto node_users = node_users_map[input_parameter];
    for (auto node_user : node_users) {
      auto cnode = node_user.first->cast_ptr<CNode>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
          (IsValueNode<FuncGraph>(cnode->inputs()[0]) && !parallel::IsTraining(manager))) {
        (void)graph_sets.insert(cnode->func_graph());
      }
    }
  }
  auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast_ptr<CNode>();
    if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    FuncGraphPtr fun_graph = nullptr;
    if (expect_prim->name() == mindspore::parallel::J ||
        ((expect_prim->name() == mindspore::parallel::SHARD) && (execution_mode == kPynativeMode))) {
      if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
        fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
      } else {
        fun_graph = node->func_graph();
      }
      (void)graph_sets.insert(fun_graph);
    }
  }
  return graph_sets;
}

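// Inserts a VirtualDataset node into every forward graph and reroutes each use
// of a graph input through TupleGetItem(virtual_dataset, input_index).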
void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(root, all_nodes);
  for (const auto &forward_graph : forward_graph_set) {
    FuncGraphManagerPtr manager = forward_graph->manager();
    MS_EXCEPTION_IF_NULL(manager);
    std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
    // SetEdge will be called later, so a deep copy of the node-user map is required.
    auto node_user_map = manager->node_users();
    auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
    std::map<size_t, CNodePtr> parameter_index_map;
    for (size_t index = 0; index < graph_inputs.size(); index++) {
      if (HasAbstractMonad(graph_inputs[index])) {
        continue;
      }
      auto node_users = node_user_map[graph_inputs[index]];
      for (const auto &node_user : node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cnode);
        for (size_t input_index = 1; input_index < cnode->size(); input_index++) {
          if (!IsValueNode<Primitive>(cnode->input(0)) && !IsValueNode<FuncGraph>(cnode->input(0)) &&
              !IsPrimitiveCNode(cnode->input(0), prim::kPrimVmap)) {
            continue;
          }
          bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
                                      IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
                                      IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
          auto node_input_iter = std::find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
          bool is_match = node_input_iter != graph_inputs.end() && is_node_input_flag &&
                          !HasAbstractMonad(cnode->inputs()[input_index]);
          if (!is_match) {
            continue;
          }
          size_t node_input_index = LongToSize(node_input_iter - graph_inputs.begin());
          if (parameter_index_map.count(node_input_index) == 0) {
            parameter_index_map[node_input_index] =
              CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
          }
          manager->SetEdge(cnode, SizeToInt(input_index), parameter_index_map[node_input_index]);
          manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
        }
      }
    }
  }
}

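// Runs the interleaved pipeline transformation: coloring, micro-batch labeling,
// parameter coloring, border cutting and parameter elimination.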
static bool PipelineInterleaved(const FuncGraphManagerPtr &mng, const FuncGraphPtr &root, int64_t stage,
                                bool gen_mask_not_fusion) {
  auto pipeline_interleave = std::make_shared<parallel::PipelineInterleave>(mng, stage, root);
  pipeline_interleave->Init();
  pipeline_interleave->Coloring();
  if (!pipeline_interleave->MainGraph()) {
    MS_LOG(EXCEPTION) << "Cannot find main graph with virtual_dataset in pipeline parallel";
  }
  pipeline_interleave->BroadCastColoring();
  if (!gen_mask_not_fusion) {
    pipeline_interleave->LabelGenMaskFusion();
  }
  pipeline_interleave->LabelMicroBatch();
  pipeline_interleave->ParameterColoring();
  pipeline_interleave->CutBorder();
  pipeline_interleave->ElimParameter();
  return true;
}

// Only auto_parallel and semi_auto_parallel support PipelineSplit
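// Splits the graph into pipeline stages: color the graph by stage, broadcast the
// coloring, label micro-batches, color shared parameters, then cut the graph and
// eliminate the stages and parameters that do not belong to this rank's stage.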
bool PipelineSplit(const ResourcePtr &res) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return true;
  }
#endif
  MS_EXCEPTION_IF_NULL(res);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
    return true;
  }

  auto manager = res->manager();
  auto root = res->func_graph();

  // Tag graphs that contain dynamic shapes.
  parallel::TagDynamicShapeFuncGraph(root);

  auto global_rank = parallel::GetRank();
  auto world_group = mindspore::parallel::GetWorldGroup();
  uint32_t world_rank_size = 0;
  int64_t device_num = 0;
  if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
    if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
      MS_LOG(EXCEPTION) << "Get rank size failed";
    }
    device_num = UintToInt(world_rank_size);
    MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
  } else {
    device_num = parallel::ParallelContext::GetInstance()->device_num();
  }

  if (device_num < 1) {
    MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
                     "but got the value of device_num: "
                  << device_num;
  }
  if (global_rank < 0) {
    MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'global_rank' must be nonnegative, "
                     "but got the value of global_rank: "
                  << global_rank;
  }
  static const auto gen_mask_not_fusion = (common::GetEnv("GENMASK_NOT_FUSION") == "1");
  auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (stage_num <= 1) {
    MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need for pipeline split.";
    auto tmp_transformer = std::make_shared<parallel::PipelineTransformer>(manager, 0, root, global_rank, global_rank);
    if (!tmp_transformer->MainGraph()) {
      return true;
    }
    if (!gen_mask_not_fusion) {
      tmp_transformer->LabelGenMaskFusion();
    }
    return true;
  }

  auto stage = parallel::InferStage();
  auto per_stage_rank_num = device_num / stage_num;
  if (parallel::ParallelInit() != parallel::SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed.";
  }
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  if (is_pp_interleave) {
    return PipelineInterleaved(manager, root, stage, gen_mask_not_fusion);
  }
  auto transformer =
    std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);

  if (parallel_context->enable_fold_pipeline()) {
    MS_LOG(INFO) << "Begin Fold Pipeline Transformer";
    transformer =
      std::make_shared<parallel::FoldPipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
  }
  // step1: Color the graph by pipeline stage.
  transformer->Coloring();
  if (!transformer->MainGraph()) {
    MS_LOG(EXCEPTION) << "Cannot find main graph with virtual_dataset in pipeline parallel";
  }

  // step2: Broadcast the coloring across the graph.
  transformer->BroadCastColoring();
  if (!gen_mask_not_fusion) {
    transformer->LabelGenMaskFusion();
  }
  transformer->LabelMicroBatch();
  // step3: Handle shared parameters.
  transformer->ParameterColoring();
  // step4: Cut the graph at stage borders.
  transformer->CutGraph();
  // step5: Eliminate other graph stages and unused parameters.
  transformer->ModifyParameterList();
  transformer->ElimGraphStage();
  return true;
}


// Only auto_parallel and semi_auto_parallel support ParallelVirtualDataset
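// Inserts the VirtualDataset boundary once per graph: searches all nodes from the
// return node and, if no VirtualDataset exists yet, calls InsertVirtualDataset.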
bool ParallelVirtualDataset(const ResourcePtr &res) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return true;
  }
#endif
  MS_EXCEPTION_IF_NULL(res);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support VirtualDataset insertion.";
    return true;
  }

  auto root = res->func_graph();
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);

  // Tag graphs that contain dynamic shapes.
  parallel::TagDynamicShapeFuncGraph(root);

  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (!HasVirtualDataset(all_nodes)) {
    InsertVirtualDataset(root, all_nodes);
  }

  return true;
}
}  // namespace pipeline
}  // namespace mindspore