/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ir/primal_attr.h"
#include "pipeline/jit/ps/pipeline_split.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/ms_context.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
#include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/parameter_manager.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

namespace mindspore {
namespace pipeline {
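// Returns true if any node in all_nodes is a VirtualDataset CNode.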
bool HasVirtualDataset(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
      return true;
    }
  }
  return false;
}

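// Builds a TupleGetItem CNode that extracts element `index` from `node`, reusing the node's scope
// and the corresponding element of its tuple abstract.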
static CNodePtr CreateTupleGetItem(const AnfNodePtr &node, size_t index, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto idx = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  CNodePtr tuple_get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  tuple_get_item->set_scope(node->scope());
  auto input_abstract_tuple = node->abstract()->cast_ptr<abstract::AbstractTuple>();
  MS_EXCEPTION_IF_NULL(input_abstract_tuple);
  auto tuple_get_item_abstract = input_abstract_tuple->elements()[index];
  MS_EXCEPTION_IF_NULL(tuple_get_item_abstract);
  tuple_get_item->set_abstract(tuple_get_item_abstract);
  return tuple_get_item;
}

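// Builds a VirtualDataset CNode over all non-monad inputs of func_graph; its abstract is a tuple
// of the cloned input abstracts.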
static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
  std::vector<AbstractBasePtr> abstract_list;
  std::vector<AnfNodePtr> virtual_dataset_node_inputs;
  for (size_t index = 0; index < func_graph->get_inputs().size(); index++) {
    if (!HasAbstractMonad(func_graph->get_inputs()[index])) {
      auto graph_input_index = func_graph->get_inputs()[index];
      auto virtual_dataset_abstract = graph_input_index->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(virtual_dataset_abstract);
      (void)abstract_list.emplace_back(virtual_dataset_abstract);
      virtual_dataset_node_inputs.push_back(func_graph->get_inputs()[index]);
    }
  }

  auto virtual_dataset_node = mindspore::parallel::CreateCNodeByInputsAndAttr(
    func_graph, mindspore::parallel::VIRTUAL_DATA_SET, mindspore::parallel::VIRTUAL_DATA_SET,
    virtual_dataset_node_inputs, {});
  MS_EXCEPTION_IF_NULL(virtual_dataset_node);
  virtual_dataset_node->set_in_forward_flag(true);
  virtual_dataset_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return virtual_dataset_node;
}

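// Collects the forward graphs: graphs in which the root's weight-free input parameters are consumed
// by primitive calls (or by func-graph calls during inference), plus graphs wrapped by J, or by
// Shard in PyNative mode.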
static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  std::set<FuncGraphPtr> graph_sets;
  auto manager = root->manager();
  if (!parallel::IsAutoParallelCareGraph(root)) {
    return graph_sets;
  }
  std::set<AnfNodePtr> input_parameters;
  for (auto &anf_param : root->parameters()) {
    auto param = anf_param->cast_ptr<Parameter>();
    MS_EXCEPTION_IF_NULL(param);
    if (!param->has_default()) {
      (void)input_parameters.insert(anf_param);
    }
  }
  for (const auto &input_parameter : input_parameters) {
    auto node_users_map = root->manager()->node_users();
    auto node_users = node_users_map[input_parameter];
    for (auto node_user : node_users) {
      auto cnode = node_user.first->cast_ptr<CNode>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (IsValueNode<Primitive>(cnode->inputs()[0]) ||
          (IsValueNode<FuncGraph>(cnode->inputs()[0]) && !parallel::IsTraining(manager))) {
        (void)graph_sets.insert(cnode->func_graph());
      }
    }
  }
  auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast_ptr<CNode>();
    if ((cnode->size() < NODE_INPUT_NUM) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    FuncGraphPtr fun_graph = nullptr;
    if (expect_prim->name() == mindspore::parallel::J ||
        ((expect_prim->name() == mindspore::parallel::SHARD) && (execution_mode == kPynativeMode))) {
      if (IsValueNode<FuncGraph>(cnode->inputs()[1])) {
        fun_graph = GetValueNode<FuncGraphPtr>(cnode->inputs()[1]);
      } else {
        fun_graph = node->func_graph();
      }
      (void)graph_sets.insert(fun_graph);
    }
  }
  return graph_sets;
}

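// Inserts a VirtualDataset node into every forward graph and reroutes each use of a graph input
// through a TupleGetItem on that VirtualDataset node.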
void InsertVirtualDataset(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  std::set<FuncGraphPtr> forward_graph_set = FindForwardGraph(root, all_nodes);
  for (const auto &forward_graph : forward_graph_set) {
    FuncGraphManagerPtr manager = forward_graph->manager();
    MS_EXCEPTION_IF_NULL(manager);
    std::vector<AnfNodePtr> graph_inputs = forward_graph->get_inputs();
    // SetEdge will be called later, so a deep copy is required.
    auto node_user_map = manager->node_users();
    auto virtual_dataset_node = CreateVirtualDataset(forward_graph);
    std::map<size_t, CNodePtr> parameter_index_map;
    for (size_t index = 0; index < graph_inputs.size(); index++) {
      if (HasAbstractMonad(graph_inputs[index])) {
        continue;
      }
      auto node_users = node_user_map[graph_inputs[index]];
      for (const auto &node_user : node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        for (size_t input_index = 1; input_index < cnode->size(); input_index++) {
          if (!IsValueNode<Primitive>(cnode->inputs()[0]) && !IsValueNode<FuncGraph>(cnode->inputs()[0]) &&
              !IsPrimitiveCNode(cnode->input(0), prim::kPrimVmap)) {
            continue;
          }
          bool is_node_input_flag = !(IsValueNode<mindspore::tensor::Tensor>(cnode->inputs()[input_index]) ||
                                      IsValueNode<ValueList>(cnode->inputs()[input_index]) ||
                                      IsValueNode<ValueTuple>(cnode->inputs()[input_index]));
          auto node_input_iter = find(graph_inputs.begin(), graph_inputs.end(), cnode->inputs()[input_index]);
          bool is_match = node_input_iter != graph_inputs.end() && is_node_input_flag &&
                          !HasAbstractMonad(cnode->inputs()[input_index]);
          if (!is_match) {
            continue;
          }
          size_t node_input_index = LongToSize(node_input_iter - graph_inputs.begin());
          if (parameter_index_map.empty() || parameter_index_map.count(node_input_index) == 0) {
            parameter_index_map[node_input_index] =
              CreateTupleGetItem(virtual_dataset_node, node_input_index, forward_graph);
          }
          manager->SetEdge(cnode, SizeToInt(input_index), parameter_index_map[node_input_index]);
          manager->SetEdge(parameter_index_map[node_input_index], 1, virtual_dataset_node);
        }
      }
    }
  }
}

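// Applies the pipeline-interleave transformation: coloring, micro-batch labeling, parameter
// coloring, border cutting and parameter elimination.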
static bool PipelineInterleaved(const FuncGraphManagerPtr &mng, const FuncGraphPtr &root, int64_t stage,
                                bool gen_mask_not_fusion) {
  auto pipeline_interleave = std::make_shared<parallel::PipelineInterleave>(mng, stage, root);
  pipeline_interleave->Init();
  pipeline_interleave->Coloring();
  if (!pipeline_interleave->MainGraph()) {
    MS_LOG(EXCEPTION) << "Cannot find main graph with virtual_dataset in pipeline parallel";
  }
  pipeline_interleave->BroadCastColoring();
  if (!gen_mask_not_fusion) {
    pipeline_interleave->LabelGenMaskFusion();
  }
  pipeline_interleave->LabelMicroBatch();
  pipeline_interleave->ParameterColoring();
  pipeline_interleave->CutBorder();
  pipeline_interleave->ElimParameter();
  return true;
}

// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return true;
  }
#endif
  MS_EXCEPTION_IF_NULL(res);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
    return true;
  }

  auto manager = res->manager();
  auto root = res->func_graph();

  // Tag dynamic shape graphs.
  parallel::TagDynamicShapeFuncGraph(root);

  auto global_rank = parallel::GetRank();
  auto world_group = mindspore::parallel::GetWorldGroup();
  uint32_t world_rank_size = 0;
  int64_t device_num = 0;
  if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
    if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
      MS_LOG(EXCEPTION) << "Get rank size failed";
    }
    device_num = UintToInt(world_rank_size);
    MS_LOG(INFO) << "Get device num from the communication module, the device num is " << device_num;
  } else {
    device_num = parallel::ParallelContext::GetInstance()->device_num();
  }

  if (device_num < 1) {
    MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
                     "but got the value of device_num: "
                  << device_num;
  }
  if (global_rank < 0) {
    MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'global_rank' must be nonnegative, "
                     "but got the value of global_rank: "
                  << global_rank;
  }
  static const auto gen_mask_not_fusion = (common::GetEnv("GENMASK_NOT_FUSION") == "1");
  auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (stage_num <= 1) {
    MS_LOG(INFO) << "The parameter 'stage_num' is: " << stage_num << ". No need for pipeline split.";
    auto tmp_transformer = std::make_shared<parallel::PipelineTransformer>(manager, 0, root, global_rank, global_rank);
    if (!tmp_transformer->MainGraph()) {
      return true;
    }
    if (!gen_mask_not_fusion) {
      tmp_transformer->LabelGenMaskFusion();
    }
    return true;
  }

  auto stage = parallel::InferStage();
  auto per_stage_rank_num = device_num / stage_num;
  if (parallel::ParallelInit() != parallel::SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed.";
  }
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  if (is_pp_interleave) {
    return PipelineInterleaved(manager, root, stage, gen_mask_not_fusion);
  }
  auto transformer =
    std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);

  if (parallel_context->enable_fold_pipeline()) {
    MS_LOG(INFO) << "Begin Fold Pipeline Transformer";
    transformer =
      std::make_shared<parallel::FoldPipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
  }
  // step1: Color the graph.
  transformer->Coloring();
  if (!transformer->MainGraph()) {
    MS_LOG(EXCEPTION) << "Cannot find main graph with virtual_dataset in pipeline parallel";
  }

  // step2: Broadcast the coloring.
  transformer->BroadCastColoring();
  if (!gen_mask_not_fusion) {
    transformer->LabelGenMaskFusion();
  }
  transformer->LabelMicroBatch();
  // step3: Handle shared parameters.
  transformer->ParameterColoring();
  // step4: Cut the graph.
  transformer->CutGraph();
  // step5: Eliminate other graph stages and unused parameters.
  transformer->ModifyParameterList();
  transformer->ElimGraphStage();
  return true;
}

// Only auto_parallel and semi_auto_parallel support ParallelVirtualDataset
bool ParallelVirtualDataset(const ResourcePtr &res) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return true;
  }
#endif
  MS_EXCEPTION_IF_NULL(res);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support ParallelVirtualDataset.";
    return true;
  }

  auto root = res->func_graph();
  AnfNodePtr ret = root->get_return();

  // Tag dynamic shape graphs.
  parallel::TagDynamicShapeFuncGraph(root);

  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (!HasVirtualDataset(all_nodes)) {
    InsertVirtualDataset(root, all_nodes);
  }

  return true;
}
}  // namespace pipeline
}  // namespace mindspore