1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/step_auto_parallel.h"
18
19 #include <algorithm>
20 #include <cinttypes>
21 #include <ctime>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <utility>
27 #include <vector>
28
29 #include "frontend/optimizer/opt.h"
30 #include "frontend/optimizer/optimizer.h"
31 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
32 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
33 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
34 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
35 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
36 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
37 #include "frontend/parallel/graph_util/graph_info.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/ops_info/reshape_info.h"
40 #include "frontend/parallel/ops_info/tmp_identity_info.h"
41 #include "frontend/parallel/parameter_manager.h"
42 #include "frontend/parallel/step_parallel.h"
43 #include "frontend/parallel/step_parallel_utils.h"
44 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
45 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
46 #include "include/common/utils/parallel_context.h"
47 #include "ir/anf.h"
48 #include "ir/param_info.h"
49 #include "ir/tensor.h"
50 #include "ops/array_ops.h"
51 #include "ops/framework_ops.h"
52 #include "ops/math_ops.h"
53 #include "ops/other_ops.h"
54 #include "ops/sequence_ops.h"
55 #include "pipeline/jit/ps/pipeline_split.h"
56 #include "utils/hash_map.h"
57 #include "utils/hash_set.h"
58 #include "utils/ms_context.h"
59 #if defined(__linux__) && defined(WITH_BACKEND)
60 #include "include/backend/distributed/ps/util.h"
61 #endif
62
63 namespace mindspore {
64 namespace parallel {
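// Usage sketch (an assumption based on the standard MindSpore Python context API, not part of this file):
// the search mode dispatched below is normally selected by the user before compiling the network, e.g.
//   import mindspore as ms
//   ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")
// where "dynamic_programming" and "recursive_programming" are the other accepted search modes.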
65 void SearchParallelStrategy(const std::string &strategy_search_mode, const FuncGraphPtr &root,
66 const std::vector<AnfNodePtr> &all_nodes) {
67 if (StrategyCheckpoint::GetInstance().LoadAutoOpStrategyOn()) {
68 if (LoadStrategyFromFile(root, all_nodes) == SUCCESS) {
69 MS_LOG(INFO) << "Load strategies success, jump searching strategy.";
70 return;
71 }
72 }
73 if ((strategy_search_mode == kDynamicProgramming) || (strategy_search_mode == kShardingPropagation)) {
74 if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
75 MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
76 << " searching mode";
77 }
78 } else if (strategy_search_mode == kRecursiveProgramming) {
79 if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
80 MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
81 }
82 } else {
83 MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
84 }
85 if (StrategyCheckpoint::GetInstance().SaveAutoOpStrategyOn()) {
86 SaveStrategyToFile();
87 }
88 }
89
90 bool HasCellShard(const FuncGraphPtr &func_graph) {
91 AnfNodePtr ret = func_graph->get_return();
92 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
93 for (auto &node : all_nodes) {
94 if (IsPrimitiveCNode(node, prim::kPrimShard) || IsPrimitiveCNode(node, prim::kPrimReshard)) {
95 return true;
96 }
97 }
98 return false;
99 }
100
101 bool IsSkipAutoParallel(const FuncGraphPtr &root, const std::string &strategy_search_mode, const bool is_pre_action) {
102 root->set_flag(kHasShard, HasCellShard(root));
103 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
104 if (root->has_flag(kSkipAutoParallelCompile) || parallel_mode != kAutoParallel ||
105 root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY) || HasNestedMetaFg(root)) {
106 return true;
107 }
108
109 // For parallel with shard, skip PreAutoParallel
110 // Shard Prim will be deleted once shard is set, see pass.cc.
111 if (root->has_flag(kHasShard)) {
112 return true;
113 }
114 if (root->has_flag(kSharded)) {
115 return false;
116 }
117 if (parallel::IsPynativeParallel() && !root->has_flag(kHasShard)) {
118 return true;
119 }
120
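// Dynamic programming searches strategies in the pass that runs after pipeline_split (the post action);
// the other search modes run before it (the pre action), so skip the phase that does not match the mode.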
121 if ((is_pre_action && strategy_search_mode == kDynamicProgramming) ||
122 (!is_pre_action && strategy_search_mode != kDynamicProgramming)) {
123 return true;
124 }
125 return false;
126 }
127
128 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
129 // The 'dynamic programming' mode runs after pipeline_split; the other modes do not.
130 MS_EXCEPTION_IF_NULL(root);
131 bool is_pre_action = !root->has_flag(AUTO_PARALLEL_FINISH_PRE_ACTION);
132 bool changes;
133 if (is_pre_action) {
134 root->set_flag(AUTO_PARALLEL_FINISH_PRE_ACTION, true);
135 auto manager = root->manager();
136 const auto &graphs = manager->func_graphs();
137 bool is_training = std::any_of(graphs.cbegin(), graphs.cend(),
138 [](auto cur_graph) -> bool { return cur_graph->has_flag(kTraining); });
139 if (is_training) {
140 root->set_flag(kTraining, true);
141 }
142 changes = true;
143 } else {
144 changes = false;
145 }
146
147 #if defined(__linux__) && defined(WITH_BACKEND)
148 if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
149 return changes;
150 }
151 #endif
152 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
153 // control whether to use model_parallel mode
154 std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
155 bool is_skip = IsSkipAutoParallel(root, strategy_search_mode, is_pre_action);
156 if (is_skip && !ParallelContext::GetInstance()->direct_split()) {
157 return changes;
158 }
159 MS_LOG(INFO) << "search_mode: " << strategy_search_mode;
160
161 // tag dynamic shape graph
162 parallel::TagDynamicShapeFuncGraph(root);
163
164 MSLogTime msTime;
165 msTime.Start();
166 #ifdef ENABLE_DUMP_IR
167 auto context = MsContext::GetInstance();
168 MS_EXCEPTION_IF_NULL(context);
169 if (context->CanDump(kIntroductory)) {
170 DumpGraph(root, std::string(STEP_AUTO_PARALLEL_BEGIN));
171 }
172 #endif
173 MS_LOG(INFO) << "Now entering step auto parallel";
174 TOTAL_OPS = 0;
175 AnfNodePtr ret = root->get_return();
176 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
177
178 // insert Virtual Dataset if it does not exist
179 if (ParallelInit() != SUCCESS) {
180 MS_LOG(EXCEPTION) << "Parallel init failed";
181 }
182 if (!mindspore::pipeline::HasVirtualDataset(all_nodes)) {
183 mindspore::pipeline::InsertVirtualDataset(root, all_nodes);
184 }
185 // redo the deep scoped search to connect the Virtual Dataset into the graph
186 all_nodes = DeepScopedGraphSearch(ret);
187
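// A device number n is a power of two iff (n & (n - 1)) == 0; the recursive search mode relies on this.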
188 if (strategy_search_mode == kRecursiveProgramming &&
189 ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0)) {
190 MS_LOG(EXCEPTION)
191 << "The recursive auto parallel strategy searching mode requires the device num be the power of 2.";
192 }
193 // mark the forward cnodes; parallel only cares about these nodes
194 MarkForwardCNode(root);
195
196 ExceptionIfHasCommunicationOp(all_nodes);
197
198 if (IsInsertVirtualOutput(root)) {
199 InsertVirtualOutput(root, all_nodes);
200 AnfNodePtr ret_after = root->get_return();
201 MS_EXCEPTION_IF_NULL(ret_after);
202 all_nodes = DeepScopedGraphSearch(ret_after);
203 }
204
205 // search parallelization strategy
206 SearchParallelStrategy(strategy_search_mode, root, all_nodes);
207 msTime.End();
208 uint64_t time = msTime.GetRunTimeUS();
209
210 MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
211
212 root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
213 return changes;
214 }
215
216 bool IsElementWiseOperator(const std::string &op_name) {
217 // clang-format off
218 static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
219 SOFTMAX, LOG_SOFTMAX, RELU,
220 SQRT, CAST, POW,
221 EXP, LOG, COS,
222 ACOS, LOGICALNOT, NEG,
223 SQUARE, SIGMOID, ABS,
224 ACOSH, ASIN, ASINH,
225 ATAN, ATANH, CEIL,
226 COSH, EXPM1, LOG1P,
227 SIN, SINH, TAN,
228 RSQRT, RECIPROCAL, INV,
229 ROUND, FLOOR, SIGN,
230 ERF, ERFC, ZEROSLIKE,
231 ONESLIKE, BESSELI0E, BESSELI0,
232 BESSELI1, BESSELJ0, BESSELJ1,
233 ASSIGN, ASSIGN_ADD, ATAN2,
234 DIVNONAN, LOGICALAND, ELU,
235 LOGICALOR, RELU6, SOFTPLUS,
236 SOFTSIGN, LESS, LESSEQUAL,
237 BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL,
238 MOD, REVERSEV2, REPEAT_ELEMENTS,
239 TRUNC, LGAMMA, CHOLESKY};
240 // clang-format on
241 auto iter = elementwise_op.find(op_name);
242 return (iter != elementwise_op.cend());
243 }
244
245 // Recording the operators appearing in a for-loop.
246 // Currently, we assume that the operators in different for-loops are identical, and their traversal
247 // orderings are also identical.
248 // Therefore, we create OperatorInfo objects for the operators in a loop (say, loop-3), and reuse them in
249 // the remaining loops (loop-2, loop-1 and loop-0).
250 std::set<std::string> ops_in_a_loop_;
251 // Returns whether two operators are in two different for-loops; if so, return true.
252 // If at least one of the two operators is not recorded in a loop, return false.
253 // If the two operators are in the same loop, return false.
254 bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
255 auto a_op_info = a_cnode->user_data<OperatorInfo>();
256 MS_EXCEPTION_IF_NULL(a_op_info);
257 auto b_op_info = b_cnode->user_data<OperatorInfo>();
258 MS_EXCEPTION_IF_NULL(b_op_info);
259 if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
260 (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
261 return false;
262 }
263 size_t a_loop_index = 0, b_loop_index = 0;
264 const auto &a_fullname = a_cnode->fullname_with_scope();
265 if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
266 MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
267 }
268 const auto &b_fullname = b_cnode->fullname_with_scope();
269 if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
270 MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
271 }
272 if (a_loop_index == b_loop_index) {
273 return false;
274 }
275 return true;
276 }
277
278 // 'configured_stra_ops_' includes all operators that have configured sharding strategies.
279 std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_;
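// 'ignore_candidate_' collects operators (currently CAST) whose OperatorInfo is detached from their CNodes
// once the search finishes; see AddOperatorToIgnoreCandidates and IgnoreOperatorsInCostGraph.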
280 std::set<OperatorInfoPtr> ignore_candidate_;
281 void InitCostGraph() {
282 if (entire_costgraph == nullptr) {
283 entire_costgraph = std::make_shared<CostGraph>();
284 }
285 MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
286 CostModelContext::GetInstance()->PrintCostModel();
287 entire_costgraph->Init();
288 configured_stra_ops_.clear();
289 ignore_candidate_.clear();
290 }
291
292 void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
293 mindspore::HashMap<std::string, ValuePtr> attrs, bool, StrategyMap *stra_map,
294 const std::string &strategy_key_name) {
295 // In this case, the configured strategy should be extracted to help set the cost
296 StrategyPtr strategyPtr;
297 if (StrategyFound(attrs)) {
298 strategyPtr = parallel::ExtractStrategy(attrs[IN_STRATEGY]);
299 } else {
300 strategyPtr = (*stra_map)[strategy_key_name];
301 }
302
303 if (strategyPtr == nullptr) {
304 return;
305 }
306
307 if (prim->name() == RESHAPE) {
308 MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
309 }
310
311 // Set cost for this configured strategy
312 if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
313 MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
314 }
315
316 const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
317 if (fully_use_devices) {
318 // If configured to fully use devices, check the user-specified strategy
319 int64_t used_devices = operator_info->used_devices();
320 MS_EXCEPTION_IF_NULL(g_device_manager);
321 auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
322
323 // 'used_devices == -1' means that 'used_devices_' is not set
324 // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
325 if (used_devices == -1 || (used_devices != 1 && LongToSize(used_devices) != total_device_num)) {
326 MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
327 << "but the specified strategy uses device: " << used_devices
328 << ", total devices: " << total_device_num
329 << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
330 "in package 'mindspore.parallel'.";
331 }
332 }
333 (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
334 }
335
336 void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
337 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
338 if (approximation) {
339 operator_info->ApproximateStrategies();
340 MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
341 }
342 }
343
344 void AddOperatorToIgnoreCandidates(const PrimitivePtr &prim, const OperatorInfoPtr &operator_info) {
345 if (prim->name() == CAST) {
346 // add CAST into ignore_candidate
347 (void)ignore_candidate_.insert(operator_info);
348 }
349 }
350
351 bool GenerateStrategiesByOperatorInfoPtr(const OperatorInfoPtr &operator_info) {
352 Status retGenStra;
353 auto attrs = operator_info->attrs();
354 if (AttrFound(attrs, STRATEGY_GEN_MODE) && GetValue<std::string>(attrs[STRATEGY_GEN_MODE]) == kDataParallel) {
355 MS_LOG(INFO) << "generating batch parallel strategy...";
356 auto prim = std::make_shared<Primitive>(operator_info->name());
357 StrategyPtr strategyPtr = parallel::GenerateBatchParallelStrategy(operator_info, prim);
358 retGenStra = operator_info->SetCostUnderStrategy(strategyPtr);
359 attrs = prim->attrs();
360 operator_info->addAttr(IN_STRATEGY, attrs[GEN_STRATEGY]); // for d-rec
361 } else {
362 MS_LOG(INFO) << "auto-searching strategy...";
363 retGenStra = operator_info->GenerateStrategies(0);
364 }
365 if (retGenStra != SUCCESS) {
366 MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
367 return false;
368 }
369 return true;
370 }
371
372 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
373 StrategyMap *stra_map) {
374 MS_EXCEPTION_IF_NULL(prim);
375 MS_EXCEPTION_IF_NULL(cnode);
376 // Create an OperatorInfo instance
377 OperatorInfoPtr operator_info = CreateOperatorInfo(cnode);
378 MS_EXCEPTION_IF_NULL(operator_info);
379 // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
380 std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
381 if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
382 MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
383 return nullptr;
384 }
385 // Set the data type for inputs and outputs of this OperatorInfo
386 auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
387 auto outputs_type = ExtractOutputTypeByNode(cnode);
388 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
389 std::string param_name = ExtractInputParameterNameByNode(cnode);
390 if (!param_name.empty()) {
391 operator_info->set_involved_param_name(param_name);
392 }
393 }
394 std::vector<size_t> outputs_type_length;
395 outputs_type_length.reserve(outputs_type.size());
396 (void)std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
397 GetLengthOfDataType);
398 if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
399 MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
400 return nullptr;
401 }
402 if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
403 MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
404 return nullptr;
405 }
406
407 operator_info->set_auto_parallel(true);
408
409 AddOperatorToIgnoreCandidates(prim, operator_info);
410 // key of strategy map
411 std::string strategy_key_name;
412 auto param_names = NodeParameterName(cnode, -1, 0);
413 if (!param_names.empty()) {
414 strategy_key_name = prim->name() + "_" + param_names[0].first;
415 }
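// As an illustration (the parameter name here is hypothetical), the key could look like
// "MatMul_net.fc1.weight": primitive name + "_" + the first bound parameter name.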
416 bool load_strategy_from_ckpt =
417 StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
418 // If no strategy has been configured for this operator, then candidate strategies are generated for
419 // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
420 // If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
421 auto attrs = prim->attrs();
422 if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
423 if ((StrategyFound(attrs) && prim->name() != CAST) || load_strategy_from_ckpt) {
424 SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
425 return operator_info;
426 }
427 }
428 // Compute split_flag_list_, indicating which input has a batch dimension. This is ONLY used in preparation
429 // for the BatchParallelInfo operator
430 operator_info->ComputeBatchSplitFlagList();
431 if ((ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming)) {
432 (void)GenerateStrategiesByOperatorInfoPtr(operator_info);
433 }
434
435 bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
436 (ParallelContext::GetInstance()->sharding_propagation())) &&
437 (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
438 if (use_sp_and_dataset) {
439 const auto &swc_vec = operator_info->GetStrategyCost();
440 if (swc_vec.empty()) {
441 MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
442 }
443 MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
444 (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
445 }
446 // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
447 ApplyApproximationForNode(operator_info);
448 return operator_info;
449 }
450
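// Returns true if a reused OperatorInfo does NOT correspond to the given primitive: the reuse is accepted
// only for VirtualDataset/BatchParallel infos or when the info name contains "<prim>Info"
// (additionally "<prim>PInfo" for GatherV2).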
451 bool IsFindWrong(const OperatorInfoPtr &current_op_ptr, const std::string &prim_name) {
452 bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
453 (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
454 (current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
455 if (prim_name == GATHERV2) {
456 is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
457 }
458 return is_find_wrong;
459 }
460
461 void AddUsersUniqueIdWhenSharingParameter(
462 const std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>> &parameter_users_info) {
463 auto users_set = parameter_users_info.second.second;
464 if (users_set.size() > 1) {
465 MS_LOG(INFO) << "Parameter " << parameter_users_info.first << " has " << users_set.size() << " users.";
466 std::vector<std::string> param_users_uniqueid;
467 for (const auto &user : users_set) {
468 MS_LOG(INFO) << "with ID: " << user.first->UniqueId() << " and name: " << user.first->UniqueName();
469 param_users_uniqueid.push_back(user.first->UniqueId());
470 }
471 entire_costgraph->add_param_users_uniqueid(param_users_uniqueid);
472 }
473 }
474
475 void AddParamUsersForRec(const std::vector<AnfNodePtr> &all_nodes) {
476 for (auto &node : all_nodes) {
477 if (node->isa<Parameter>()) {
478 ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
479 AddUsersUniqueIdWhenSharingParameter(parameter_users_info);
480 }
481 }
482 }
483
484 // Using CNode's UniqueIds to construct nodes
485 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
486 MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
487 // The map from CNode's UniqueId to its operatorInfo
488 std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
489 // The operator_infos in a loop
490 std::vector<OperatorInfoPtr> operators_in_forloop;
491 // Key: i-th loop; Value: index of 'operators_in_forloop'
492 std::map<size_t, size_t> loop_to_ops;
493 // extract strategy from checkpoint for multi-train
494 StrategyMap stra_map;
495 if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
496 if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
497 MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
498 }
499 }
500
501 for (auto &node : all_nodes) {
502 // NOTE: we only care about splittable Primitive operators
503 auto cnode = node->cast<CNodePtr>();
504 bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
505 if (bool_result) {
506 continue;
507 }
508 auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
509 if (!IsAutoParallelCareNode(cnode)) {
510 // Needed by rec_parser
511 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
512 auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
513 if (prev_cnode != nullptr) {
514 entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
515 }
516 }
517 continue;
518 }
519 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
520 MS_EXCEPTION_IF_NULL(prim);
521
522 auto search_cnode = from_cnode_to_info.find(cnode->UniqueId() + prim->name());
523 if (search_cnode == from_cnode_to_info.cend()) {
524 size_t loop_index = 0;
525 bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
526 const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
527 if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
528 const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
529 if (IsFindWrong(current_op_ptr, prim->name())) {
530 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
531 << " does not match the Prim: " << prim->name()
532 << ". The fullname_with_scope: " << cnode->fullname_with_scope();
533 }
534 loop_to_ops[loop_index]++;
535 cnode->set_user_data<OperatorInfo>(current_op_ptr);
536 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
537 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
538 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
539 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
540 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId() + prim->name(), current_op_ptr));
541 continue;
542 }
543 bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
544 auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
545 if (operator_info == nullptr) {
546 return FAILED;
547 }
548 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
549 operator_info->set_type(prim->name());
550 operator_info->set_last_node_flag(is_last_nodes);
551 std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode, all_nodes);
552 entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
553 }
554
555 entire_costgraph->AddOperator(operator_info);
556 cnode->set_user_data<OperatorInfo>(operator_info);
557 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
558 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
559 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
560 << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
561 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId() + prim->name(), operator_info));
562 if (single_loop && is_in_loop) {
563 operators_in_forloop.push_back(operator_info);
564 (void)ops_in_a_loop_.insert(operator_info->name());
565 loop_to_ops[loop_index]++;
566 }
567 } else {
568 // Two CNODEs' UniqueIds should not be equal
569 MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
570 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
571 << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
572 }
573 }
574
575 MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
576 // Needed by rec_parser 2
577 AddParamUsersForRec(all_nodes);
578
579 return SUCCESS;
580 }
581
582 void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) {
583 if (current_op_ptr == nullptr) {
584 MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
585 } else {
586 if (IsFindWrong(current_op_ptr, prim->name())) {
587 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
588 << " does not match the Prim: " << prim->name();
589 }
590
591 // Needed by rec_parser
592 ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
593
594 cnode->set_user_data<OperatorInfo>(current_op_ptr);
595 current_op_ptr->set_cnode(cnode);
596 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
597 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
598 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
599 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
600 }
601 }
602
603 // Using CNode's UniqueIdThroughCopys to construct nodes
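// Unlike ConstructCostGraphNodesByUniqueId, CNodes that are copies of each other share the same
// UniqueIdThroughCopy, so they are mapped to a single OperatorInfo instead of raising an exception.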
604 Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
605 MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
606 // The map from CNode's UniqueIdThroughCopy to its operatorInfo
607 std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
608 // The operator_infos in a loop
609 std::vector<OperatorInfoPtr> operators_in_forloop;
610 // Key: i-th loop; Value: index of 'operators_in_forloop'
611 std::map<size_t, size_t> loop_to_ops;
612 // extract strategy from checkpoint for multi-train
613 StrategyMap stra_map;
614 if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
615 StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
616 MS_LOG(WARNING) << "Load strategy checkpoint failed";
617 return FAILED;
618 }
619 for (auto &node : all_nodes) {
620 // NOTE: we only care about splittable Primitive operators
621 auto cnode = node->cast<CNodePtr>();
622 if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
623 continue;
624 }
625 auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
626 if (!IsAutoParallelCareNode(cnode)) {
627 // Needed by rec_parser
628 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
629 auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
630 if (prev_cnode != nullptr) {
631 entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
632 }
633 }
634 continue;
635 }
636 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
637
638 // Find the operatorInfo if it exists
639 auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy() + prim->name());
640 if (search_cnode == from_cnode_to_info.cend()) {
641 size_t loop_index = 0;
642 bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
643 const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
644 bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
645 if (is_op_created) {
646 const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
647 if (IsFindWrong(current_op_ptr, prim->name())) {
648 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
649 << " does not match the Prim: " << prim->name()
650 << ". The fullname_with_scope: " << cnode->fullname_with_scope();
651 }
652 loop_to_ops[loop_index]++;
653 cnode->set_user_data<OperatorInfo>(current_op_ptr);
654 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
655 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
656 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
657 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
658 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy() + prim->name(), current_op_ptr));
659 continue;
660 }
661 // In this case, the corresponding OperatorInfo has not been created, so create a new one.
662 bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
663 auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
664 MS_EXCEPTION_IF_NULL(operator_info);
665
666 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
667 operator_info->set_type(prim->name());
668 operator_info->set_last_node_flag(is_last_nodes);
669 std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode, all_nodes);
670 entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
671 }
672
673 entire_costgraph->AddOperator(operator_info);
674 cnode->set_user_data<OperatorInfo>(operator_info);
675 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
676 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
677 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
678 << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
679 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy() + prim->name(), operator_info));
680 if (single_loop && is_in_loop) {
681 operators_in_forloop.push_back(operator_info);
682 (void)ops_in_a_loop_.insert(operator_info->name());
683 loop_to_ops[loop_index]++;
684 }
685 } else {
686 SetOperatorToCNode(search_cnode->second, prim, cnode);
687 }
688 }
689
690 MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
691 // Needed by rec_parser 2
692 AddParamUsersForRec(all_nodes);
693
694 return SUCCESS;
695 }
696
697 void PreProcessPreCastForSP(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
698 const CNodePtr &cnode, const EdgePtr edge_ptr, size_t input_index) {
699 if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) && input_index == INDEX_TWO) {
700 prev_op_info->set_repeated_num_in_dev_matrix_right(false);
701 prev_op_info->ClearStrategyCost();
702 (void)prev_op_info->GenerateStrategies(0);
703 }
704 if ((configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
705 const auto next_op_stra = configured_stra_ops_[node_op_info];
706 if (edge_ptr->InitEdgeCost() != SUCCESS) {
707 MS_LOG(EXCEPTION) << "Edge cost initialization failed";
708 }
709 const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
710 if (cast_stra == nullptr) {
711 MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
712 }
713 prev_op_info->ClearStrategyCost();
714 if (prev_op_info->SetCostUnderStrategy(cast_stra) != SUCCESS) {
715 MS_LOG(EXCEPTION) << "Failure: operator " << prev_op_info->name() << " SetCostUnderStrategy failed";
716 }
717 if (edge_ptr->InitEdgeCost() != SUCCESS) {
718 MS_LOG(EXCEPTION) << "Edge cost re-initialization failed.";
719 }
720 MS_LOG(INFO) << "Set strategy for: " << prev_op_info->name() << " under the strategy of: " << node_op_info->name();
721 (void)configured_stra_ops_.emplace(prev_op_info, cast_stra);
722 }
723 }
724
725 void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
726 const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim,
727 const PrimitivePtr &prev_prim, size_t output_index, size_t input_index,
728 size_t *edge_count) {
729 std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
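// Note: cnode->input(0) holds the Primitive, so data inputs start at index 1; edge input indices are
// therefore passed as 'input_index - 1'.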
730 // If the edge between these two operators has already been added, the edge will not be added again.
731 if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) {
732 return;
733 }
734 EdgePtr edge_ptr;
735 MS_LOG(INFO) << "Creating edge: " << edge_name;
736 if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
737 MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
738 << ", cnode_fullname: " << cnode->fullname_with_scope();
739 MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge.";
740 return;
741 }
742 const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
743 bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
744 (stra_follow && IsElementWiseOperator(prev_prim->name()));
745 if (follow_strategy) {
746 // Redistribution is not allowed on the edge.
747 // Elementwise operators have the same strategy as their previous operators.
748 edge_ptr =
749 std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true);
750 } else {
751 edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false);
752 }
753 bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
754 (ParallelContext::GetInstance()->sharding_propagation());
755 // Init costs for this edge
756 if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
757 if (!use_sp && edge_ptr->InitEdgeCost() != SUCCESS) {
758 MS_LOG(EXCEPTION) << "Edge cost initialization failed";
759 }
760 }
761 node_op_info->AddPrevEdge(edge_ptr);
762 prev_op_info->AddSuccEdge(edge_ptr);
763 entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
764 if (use_sp && prev_prim->name() == CAST) {
765 PreProcessPreCastForSP(prev_op_info, node_op_info, cnode, edge_ptr, input_index);
766 }
767 MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name();
768 (*edge_count)++;
769 }
770
771 void ApplyApproximationForGraphs() {
772 // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
773 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
774 if (approximation) {
775 entire_costgraph->CheckApproximateCostGraphEdges();
776 }
777 }
778
779 static void CreateEdgeAccrossMakeList(const CNodePtr &cnode, const PrimitivePtr &prim,
780 const OperatorInfoPtr &node_op_info, CNodePtr *prev_cnode,
781 ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim, size_t *edge_count) {
782 MS_LOG(INFO) << "Creating edges across the 'make_list' operator.";
783 const auto &sub_inputs = (*prev_cnode)->inputs();
784 for (size_t j = 1; j < sub_inputs.size(); ++j) {
785 *prev_cnode = sub_inputs[j]->cast<CNodePtr>();
786 bool bool_result_list = (*prev_cnode == nullptr) || !IsValueNode<Primitive>((*prev_cnode)->input(0)) ||
787 !IsAutoParallelCareNode(*prev_cnode);
788 if (bool_result_list) {
789 continue;
790 }
791 *prev_prim_anf_node = (*prev_cnode)->input(0)->cast<ValueNodePtr>();
792 *prev_prim = (*prev_prim_anf_node)->value()->cast<PrimitivePtr>();
793 auto prev_op_info = (*prev_cnode)->user_data<OperatorInfo>();
794 CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, *prev_cnode, prim, *prev_prim, 0, j, edge_count);
795 }
796 }
797
798 static void ConstructCNodeCostGraphEdges(const mindspore::CNodePtr &cnode, const std::vector<AnfNodePtr> &all_nodes) {
799 auto &inputs = cnode->inputs();
800 ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
801 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
802 size_t edge_count = 0;
803 auto node_op_info = cnode->user_data<OperatorInfo>();
804
805 for (size_t i = 1; i < inputs.size(); ++i) {
806 AnfNodePtr prev_node = inputs[i];
807 if (inputs[i]->isa<Parameter>()) {
808 prev_node = FindRealInputByFormalParameter(cnode, inputs[i], all_nodes);
809 if (prev_node->UniqueId() == inputs[i]->UniqueId()) {
810 continue;
811 }
812 }
813 auto prev_cnode = prev_node->cast<CNodePtr>();
814 PrimitivePtr prev_prim;
815 ValueNodePtr prev_prim_anf_node;
816 bool is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
817 if (is_cross) {
818 continue;
819 }
820 size_t output_index = 0;
821 bool is_before_tuple_get_item = false;
822
823 while (IsCarePrevCNode(prev_cnode, prev_prim)) {
824 if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
825 auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
826 auto output = graph->output();
827 MS_EXCEPTION_IF_NULL(output);
828 prev_cnode = output->cast<CNodePtr>();
829 (void)CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
830 } else if (IsAutoParallelCareNode(prev_cnode)) {
831 auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
832 CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
833 &edge_count);
834 break;
835 } else if (prev_prim->name() == prim::kPrimTupleGetItem->name()) {
836 // In this case, 'prev_anf_node' is 'tuple_getitem'; the actual precursor node is the node before
837 // this 'tuple_getitem'
838 output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
839 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
840 is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
841 if (is_cross) {
842 break;
843 }
844 if (!IsAutoParallelCareNode(prev_cnode) && !IsValueNode<FuncGraph>(prev_cnode->input(0))) {
845 MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
846 }
847 is_before_tuple_get_item = true;
848 } else if (prev_prim->name() == kMakeTupleOpName) {
849 if (!is_before_tuple_get_item) {
850 CreateEdgeAccrossMakeList(cnode, prim, node_op_info, &prev_cnode, &prev_prim_anf_node, &prev_prim,
851 &edge_count);
852 break;
853 }
854 prev_cnode = prev_cnode->input(output_index + 1)->cast<CNodePtr>();
855 output_index = 0;
856 is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
857 if (is_cross) {
858 break;
859 }
860 is_before_tuple_get_item = false;
861 } else if (prev_prim->name() == kMakeListOpName) {
862 CreateEdgeAccrossMakeList(cnode, prim, node_op_info, &prev_cnode, &prev_prim_anf_node, &prev_prim, &edge_count);
863 break;
864 } else if (prev_prim->name() == kDependOpName || prev_prim->name() == kLoadOpName) {
865 // In this case, 'prev_anf_node' is 'depend'; the actual precursor node is the node before
866 // this 'depend'
867 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
868 is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
869 if (is_cross) {
870 break;
871 }
872 is_before_tuple_get_item = true;
873 }
874 }
875 }
876 MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
877 }
878
879 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
880 // Step 2
881 MS_LOG(INFO) << "Constructing edges for cost graph begins.";
882 for (auto &node : all_nodes) {
883 auto cnode = node->cast<CNodePtr>();
884 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
885 continue;
886 }
887 if (!IsAutoParallelCareNode(cnode)) {
888 continue;
889 }
890 ConstructCNodeCostGraphEdges(cnode, all_nodes);
891 }
892 ApplyApproximationForGraphs();
893
894 MS_LOG(INFO) << "Constructing edges for cost graph ends.";
895 }
896
897 void ApplyApproximationForParaNode(const OperatorInfoPtr &target_op_info) {
898 // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
899 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
900 if (approximation) {
901 target_op_info->ExactStrategiesAndRelatedEdges();
902 }
903 }
904
905 std::pair<OperatorInfoPtr, bool> CreateIdentityOp(const std::string &parameter_name,
906 const AnfNodePtr &target_parameter) {
907 // Here, it is certain that this Parameter (RefKey) is being used by multiple Operators.
908 OperatorInfoPtr tmp_identity_ptr;
909 bool new_identity = false;
910 auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
911 if (returned_identity != nullptr) {
912 // In this case, the TmpIdentityInfo instance has already been created
913 new_identity = false;
914 tmp_identity_ptr = returned_identity;
915 } else {
916 // In this case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
917 new_identity = true;
918 // 1) extract input shape from this Parameter
919 MS_EXCEPTION_IF_NULL(target_parameter);
920 AbstractBasePtr abstract = target_parameter->abstract();
921 if (abstract == nullptr) {
922 MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
923 }
924 auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
925 if (input_shape == nullptr) {
926 MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
927 }
928 Shape shape = input_shape->shape();
929 Shapes inputs_shape = {shape};
930 Shapes outputs_shape = {shape};
931 // 2) init the attr
932 mindspore::HashMap<std::string, ValuePtr> attr = {};
933
934 // Create the TmpIdentity instance
935 tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
936 tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
937 TOTAL_OPS++;
938 tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
939 // Set the parameter and type lengths for inputs and outputs
940 std::vector<bool> is_parameter;
941 auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
942 MS_EXCEPTION_IF_NULL(casted_target_parameter);
943 is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
944 if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
945 MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
946 }
947 auto node_type = target_parameter->Type();
948 if (node_type->isa<mindspore::TensorType>()) {
949 auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
950 std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
951 if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
952 MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
953 }
954 } else {
955 MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
956 }
957
958 // Generate strategies for this TmpIdentityInfo instance;
959 if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
960 MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
961 }
962 }
963 return std::make_pair(tmp_identity_ptr, new_identity);
964 }
965
966 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
967 // Step 3
968 for (auto &node : all_nodes) {
969 ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode, all_nodes);
970 auto parameter_name = parameter_users_info.first;
971 auto target_parameter = parameter_users_info.second.first;
972 auto target_set = parameter_users_info.second.second;
973 if (target_set.size() <= 1) {
974 continue;
975 }
976
977 // Rule out the case where a Parameter is used by an Operator, but the Operator appears in multiple CNODEs
978 std::set<std::string> target_without_duplicate;
979 for (auto &target : target_set) {
980 auto target_cnode = target.first->cast<CNodePtr>();
981 // Eliminate the ops without cost.
982 if (IsSomePrimitive(target_cnode, SEND)) {
983 continue;
984 }
985 auto input_index = target.second;
986 (void)target_without_duplicate.insert(std::to_string(input_index) +
987 target_cnode->user_data<OperatorInfo>()->name());
988 }
989 if (target_without_duplicate.size() <= 1 || parameter_name.empty()) {
990 continue;
991 }
992
993 auto pair = CreateIdentityOp(parameter_name, target_parameter);
994 OperatorInfoPtr tmp_identity_ptr = pair.first;
995 bool new_identity = pair.second;
996 // A flag recording whether new edges have been created or not
997 bool add_identity_edge = false;
998
999 // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
1000 for (auto &target : target_set) {
1001 auto target_cnode = target.first->cast<CNodePtr>();
1002 auto input_index = target.second;
1003 auto target_op_info = target_cnode->user_data<OperatorInfo>();
1004 if (!target_op_info->repeated_num_in_dev_matrix_right() && tmp_identity_ptr->repeated_num_in_dev_matrix_right()) {
1005 tmp_identity_ptr->set_repeated_num_in_dev_matrix_right(false);
1006 tmp_identity_ptr->ClearStrategyCost();
1007 (void)tmp_identity_ptr->GenerateStrategies(0);
1008 }
1009
1010 std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
1011 // If the edge between these two operators has already been added, the edge will not be added again.
1012 if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
1013 continue;
1014 }
1015 std::shared_ptr<Edge> edge_ptr =
1016 std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
1017 ApplyApproximationForParaNode(target_op_info);
1018
1019 bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1020 (ParallelContext::GetInstance()->sharding_propagation());
1021 if (!use_sp && edge_ptr->InitEdgeCost() != SUCCESS) {
1022 MS_LOG(EXCEPTION) << "Edge cost initialization failed";
1023 }
1024 target_op_info->AddPrevEdge(edge_ptr);
1025 tmp_identity_ptr->AddSuccEdge(edge_ptr);
1026 entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
1027 MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
1028 << target_op_info->name();
1029 add_identity_edge = true;
1030 }
1031 if (new_identity && add_identity_edge) {
1032 // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
1033 entire_costgraph->AddOperator(tmp_identity_ptr);
1034 }
1035 }
1036 }
1037
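// For each Reshape operator, collect the strategy costs of its preceding and succeeding operators and use
// them to generate the Reshape's own strategy costs (the recursive mode instead builds cost-graph edges
// for the surrounding operators).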
1038 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
1039 mindspore::HashSet<std::string> op_cache;
1040 for (const auto &node : all_nodes) {
1041 auto cnode = node->cast<CNodePtr>();
1042 if (!FindReshape(cnode, &op_cache)) {
1043 continue;
1044 }
1045 MS_ASSERT(cnode->size() == 3);
1046 // get previous node's strategy_cost_
1047 auto pre_node = cnode->input(1);
1048 if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
1049 pre_node = pre_node->cast<CNodePtr>()->input(1);
1050 }
1051 int64_t out_index = 0;
1052 OperatorInfoPtr pre_operator_info;
1053 std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
1054 auto operator_info = cnode->user_data<OperatorInfo>();
1055 bool is_prev_param = false;
1056 if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &is_prev_param, &out_index, 0)) {
1057 MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
1058 }
1059 // In the double-recursive (recursive programming) mode, enumerate the strategies of Reshape and its preceding operator
1060 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
1061 ConstructCNodeCostGraphEdges(cnode, all_nodes);
1062 }
1063 if (is_prev_param) {
1064 auto reshape_info1 = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
1065 reshape_info1->SetCostForReshapeWithParameter();
1066 pre_operator_info = reshape_info1;
1067 pre_stra_costs = reshape_info1->strategy_cost();
1068 } else {
1069 pre_stra_costs = pre_operator_info->strategy_cost();
1070 }
1071 // get next node's strategy_cost_
1072 std::vector<std::pair<OperatorInfoPtr, int64_t>> next_ops_index;
1073 bool is_next_reshape = false;
1074 std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index;
1075 (void)FindReshapeNextNodeStraCosts(cnode, &next_ops_index, &is_next_reshape, 0);
1076 if (next_ops_index.empty()) {
1077 MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
1078 }
1079 // set input_layout and output_layout for reshape.
1080 // init reshape and set cost for each input_layout and output_layout.
1081 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
1082 reshape_info->set_pre_operator_name(pre_operator_info->name());
1083 reshape_info->set_pre_operator_index(out_index);
1084 if (!next_ops_index.empty()) {
1085 for (auto &op_index : next_ops_index) {
1086 // In the double-recursive (recursive programming) mode, enumerate the strategies of Reshape's succeeding operators
1087 if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
1088 ConstructCNodeCostGraphEdges(op_index.first->cnode(), all_nodes);
1089 }
1090 auto op_cost = op_index.first->strategy_cost();
1091 (void)next_costs_index.emplace_back(std::make_pair(op_cost, op_index.second));
1092 }
1093 reshape_info->set_next_operator_name(next_ops_index[0].first->name());
1094 reshape_info->set_next_operator_index(next_ops_index[0].second);
1095 }
1096 if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
1097 if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_costs_index, out_index, is_prev_param,
1098 is_next_reshape) != SUCCESS) {
1099 MS_LOG(EXCEPTION) << "Reshape generate strategy costs failed";
1100 }
1101 }
1102 }
1103 }
1104
1105 Status IgnoreOperatorsInCostGraph() {
1106 for (auto op = ignore_candidate_.cbegin(); op != ignore_candidate_.cend(); ++op) {
1107 auto cnodes = (*op)->cnodes();
1108 for (auto &cnode : cnodes) {
1109 MS_EXCEPTION_IF_NULL(cnode);
1110 cnode->set_user_data<OperatorInfo>(nullptr);
1111 }
1112 }
1113 return SUCCESS;
1114 }
1115
1116 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1117 // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
1118 // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
1119 // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
1120 // for each OperatorInfo;
1121 // Step 1.1: Deal with 'Reshape':
1122 // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
1123 // layout as its output layout.
1124 // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
1125 // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
1126 // for each edge, based on the strategies of two OperatorInfos;
1127 // Step 3: Augment the costgraph:
1128 // taking care of the case where a single Parameter is used by multiple operators. Create a TmpIdentity
1129 // operator for this Parameter, and add an edge for the use of this Parameter by each
1130 // subsequent operator;
1131 // Step 3.1: Calculate memory usage:
1132 // note that the memory usage calculation differs between the training phase and the inference phase.
1133 // Step 4: Run the strategy searching algorithm:
1134 // If 'sharding_propagation' is configured to be true, then the configured-sharding-strategies will propagate
1135 // to the non-configured operators, with the goal of minimizing redistribution cost.
1136 // Otherwise, the DP algorithm is used to search strategies for the costgraph. Note that there may be several connected
1137 // components in the costgraph, and the DP algorithm runs on each of them.
1138 //
1139 // OUTPUT: the determined strategy for each operator.
1140
1141 InitCostGraph();
1142 bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1143 (ParallelContext::GetInstance()->sharding_propagation());
1144 // Step 1
1145 if (CostModelContext::GetInstance()->is_multi_subgraphs() || use_sp) {
1146 if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1147 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1148 << entire_costgraph->GetOperators().size() << " operators.";
1149 } else {
1150 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1151 }
1152 } else {
1153 if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1154 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1155 << entire_costgraph->GetOperators().size() << " operators.";
1156 } else {
1157 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1158 }
1159 }
1160 // Step 1.1
1161 ReshapeCostCompute(all_nodes);
1162 // Step 2
1163 ConstructCostGraphEdges(all_nodes);
1164 MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
1165 << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
1166
1167 // Step 3: Augment the costgraph.
1168 AugmentCostGraph(all_nodes);
1169 auto num_ops = entire_costgraph->GetOperators().size();
1170 SetOpsNumToExecutor(num_ops);
1171 auto num_edges = entire_costgraph->GetNumEdges();
1172 MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
1173 << " edges.";
1174
1175 // Step 3.1: Calculate the memory usage
1176 if (!use_sp && entire_costgraph->CalculateMemoryCost() != SUCCESS) {
1177 MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
1178 }
1179
1180 // Step 4: run the strategy searching algorithm
1181 if (use_sp) {
1182 entire_costgraph->StrategyPropagate(configured_stra_ops_);
1183 } else if (GetStrategy(entire_costgraph) != SUCCESS) {
1184 MS_LOG(ERROR) << "Strategy search for cost-graph fails";
1185 return FAILED;
1186 }
1187 MS_LOG(INFO) << "Searching strategy succeeded.";
1188
1189 if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1190 MS_LOG(INFO) << "Init selected strategy succeeded.";
1191 } else {
1192 MS_LOG(EXCEPTION) << "Init selected strategy failed.";
1193 }
1194
1195 // print the selected strategy
1196 for (auto &op : entire_costgraph->GetOperators()) {
1197 StrategyPtr s_strategy = op->selected_strategy();
1198 if (s_strategy != nullptr) {
1199 MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1200 }
1201 }
1202 // Remove some OperatorInfos from the CNodes
1203 (void)IgnoreOperatorsInCostGraph();
1204
1205 ops_in_a_loop_.clear();
1206 configured_stra_ops_.clear();
1207 ignore_candidate_.clear();
1208
1209 return SUCCESS;
1210 }
1211
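// Replace, in every operator's input-tensor-name list, the occurrence of the name it->first with it->second.
// Used below to resolve tuple_getitem indirections before feeding the recursive-programming parser.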
1212 std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
1213 std::vector<std::vector<std::string>> input_tensor_names) {
1214 for (size_t j = 0; j < input_tensor_names.size(); j++) {
1215 for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
1216 if (it->first == input_tensor_names[j][k]) {
1217 input_tensor_names[j][k] = it->second;
1218 break;
1219 }
1220 }
1221 }
1222 return input_tensor_names;
1223 }
1224
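// For a TupleGetItem or Depend node, walk back through its first input (descending into func-graph outputs and
// skipping Depend/MakeTuple wrappers) until a CNode holding a real primitive is found; return nullptr if the
// chain cannot be resolved.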
1225 CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
1226 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1227 if (prim->name() == prim::kPrimTupleGetItem->name() || prim->name() == DEPEND) {
1228 auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
1229 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1230 return nullptr;
1231 }
1232 if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
1233 size_t out_index = 0;
1234 out_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(INDEX_TWO))));
1235 auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
1236 auto output = graph->output();
1237 MS_EXCEPTION_IF_NULL(output);
1238 while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
1239 auto output_cnode = output->cast<CNodePtr>();
1240 MS_EXCEPTION_IF_NULL(output_cnode);
1241 output = output_cnode->input(1);
1242 }
1243 while (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
1244 auto make_tuple_cnode = output->cast<CNodePtr>();
1245 output = make_tuple_cnode->input(out_index + 1);
1246 }
1247 prev_cnode = output->cast<CNodePtr>();
1248 }
1249
1250 auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1251 while (prev_prim->name() == prim::kPrimTupleGetItem->name() || prev_prim->name() == DEPEND) {
1252 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
1253 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1254 return nullptr;
1255 }
1256 prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1257 }
1258 return prev_cnode;
1259 }
1260 return nullptr;
1261 }
1262
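// When an OperatorInfo named 'name' has already been created, replace every occurrence of 'uniqueid' in the
// inputs-tensor-name list with that operator's own first tensor name.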
1263 void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
1264 size_t iter_ops = 0;
1265 for (const auto &op : entire_costgraph->GetOperators()) {
1266 if (op->name() == name) {
1267 break;
1268 }
1269 iter_ops = iter_ops + 1;
1270 }
1271
1272 std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1273 for (size_t i = 0; i < input_tensor_names.size(); i++) {
1274 for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
1275 if (input_tensor_names[i][j] == uniqueid) {
1276 input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
1277 }
1278 }
1279 }
1280
1281 entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
1282 }
1283
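// Return the index of the operator whose first tensor name matches unique_id, or SIZE_MAX if no such operator exists.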
1284 size_t FindOperatorIndexById(const std::string &unique_id,
1285 const std::vector<std::vector<std::string>> &input_tensor_names) {
1286 for (size_t i = 0; i < input_tensor_names.size(); i++) {
1287 if (input_tensor_names[i][0] == unique_id) {
1288 return i;
1289 }
1290 }
1291 return SIZE_MAX;
1292 }
1293
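// For every parameter, translate the unique ids of its user operators into operator indices; users that cannot be
// found in input_tensor_names are skipped.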
1294 std::vector<std::vector<size_t>> GetIndexOfOpsSharingInputTensor(
1295 const std::vector<std::vector<std::string>> ¶m_users_uniqueid_list,
1296 const std::vector<std::vector<std::string>> &input_tensor_names) {
1297 std::vector<std::vector<size_t>> param_users_ops_index;
1298 for (const auto &users_uniqueid : param_users_uniqueid_list) {
1299 std::vector<size_t> users_index;
1300 for (const auto &user_uniqueid : users_uniqueid) {
1301 size_t user_index = FindOperatorIndexById(user_uniqueid, input_tensor_names);
1302 if (user_index != SIZE_MAX) {
1303 users_index.push_back(user_index);
1304 }
1305 }
1306 param_users_ops_index.push_back(users_index);
1307 }
1308 return param_users_ops_index;
1309 }
1310
1311 void CalculateMicroBatchSize(const std::shared_ptr<Graph> &graph, const FuncGraphPtr &root) {
1312 // The first dimension of an operator is its batch dimension.
1313 // However, the size of that dimension is not necessarily the batch_size assigned by the user.
1314 // This function calculates the micro batch size in the pipeline scenario.
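// For example (illustrative numbers): if GetNext outputs a batch dimension of 32 and the virtual dataset output
// has 4 users, the resulting micro batch size is 32 / 4 = 8.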
1315
1316 auto manager = root->manager();
1317 auto ops = entire_costgraph->GetOperators();
1318 AnfNodePtr virtual_dataset_;
1319 for (auto &fg : manager->func_graphs()) {
1320 for (auto &node : fg->nodes()) {
1321 if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
1322 virtual_dataset_ = node;
1323 break;
1324 }
1325 }
1326 }
1327 if (!virtual_dataset_) {
1328 // Normally for auto parallel, a virtual dataset is required in order to control the input's parallel strategy.
1329 // However, in some test cases or networks, there is no input data.
1330 // This branch handles those cases by setting the micro batch size to 1.
1331 graph->micro_batch_size = 1;
1332 return;
1333 }
1334 auto node_user_map = manager->node_users();
1335 auto node_users = node_user_map[virtual_dataset_];
1336 int64_t data_user_size = 0;
1337 int64_t total_batch_size = 0;
1338 for (auto &node_user : node_users) {
1339 if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
1340 auto data_users = manager->node_users()[node_user.first];
1341 auto node_first = data_users.front().first;
1342 if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
1343 data_users.clear();
1344 data_users = node_user_map[node_first];
1345 }
1346 data_user_size = int64_t(data_users.size());
1347 }
1348 }
1349
1350 for (auto op : ops) {
1351 if (op->type() == GET_NEXT) {
1352 for (auto shape : op->outputs_shape()) {
1353 if (!shape.empty()) {
1354 total_batch_size = shape[0];
1355 break;
1356 }
1357 }
1358 break;
1359 }
1360 }
1361 if (data_user_size != 0) {
1362 graph->micro_batch_size = total_batch_size / data_user_size;
1363 MS_LOG(INFO) << "In the pipeline scenario, the micro_batch_size of each stage is " << graph->micro_batch_size;
1364 } else {
1365 MS_LOG(EXCEPTION) << "Data user size equals to 0, which could not be divided by the total batch size";
1366 }
1367 }
1368
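// Build the cost-graph nodes, using the unique-id-TC variant when the graph contains multiple subgraphs.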
1369 void CreateNodesForCostGraph(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1370 if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
1371 if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1372 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1373 << entire_costgraph->GetOperators().size() << " operators.";
1374 } else {
1375 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1376 }
1377 } else {
1378 if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1379 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1380 << entire_costgraph->GetOperators().size() << " operators.";
1381 } else {
1382 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1383 }
1384 }
1385 }
1386
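// Rebuild the cost graph from scratch; the Reshape cost computation is skipped in the temporary dynamic-shape flow.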
1387 void ReInitCostGraph(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, bool dyn_shape_tmp_fix) {
1388 InitCostGraph();
1389 CreateNodesForCostGraph(all_nodes, root);
1390 if (!dyn_shape_tmp_fix) {
1391 ReshapeCostCompute(all_nodes);
1392 }
1393 }
1394
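// Write each operator's selected strategy back to its CNode as the IN_STRATEGY primal attribute; Cast and Reshape
// are skipped.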
1395 void WriteStrategiesBackToAnfGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
1396 for (auto &op : ops) {
1397 auto op_type = op->type();
1398 if (op_type == CAST || op_type == RESHAPE) {
1399 continue;
1400 }
1401 auto op_strategy = op->selected_strategy()->GetInputDim();
1402 if (!op_strategy.empty()) {
1403 std::vector<ValuePtr> strategies;
1404 (void)std::transform(op_strategy.begin(), op_strategy.end(), std::back_inserter(strategies),
1405 [](const Dimensions &dim) { return MakeValue(dim); });
1406 ValueTuplePtr var = std::make_shared<ValueTuple>(strategies);
1407 op->cnode()->AddPrimalAttr(parallel::IN_STRATEGY, var);
1408 }
1409 }
1410 }
1411
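// Temporary dynamic-shape workaround for BatchMatMul: when the first input's shape_c is known and the second
// input's is not, propagate the known value to the second input, to the node's output shape, and to the matching
// input of its first outgoing node.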
1412 void TMpInferBatchMatMul(const std::shared_ptr<Graph> &graph, Graph::NodeType *node) {
1413 if (node->apply.arguments[0].tensor_shape.shape_c != -1 && node->apply.arguments[1].tensor_shape.shape_c == -1) {
1414 auto infer_shape = node->apply.arguments[0].tensor_shape.shape_c;
1415 node->apply.arguments[1].tensor_shape.shape_c = infer_shape;
1416
1417 if (node->node_out.size() == 0) {
1418 MS_LOG(EXCEPTION) << "The current BatchMatMul (" << node->name << ") does not have an outgoing node.";
1419 }
1420 auto &outgoing_node = graph->nodes[node->node_out[0]];
1421 if (outgoing_node.apply.arguments[0].tensor_shape.shape_c == node->tensor_parm.tensor_shape.shape_c) {
1422 outgoing_node.apply.arguments[0].tensor_shape.shape_c = infer_shape;
1423 }
1424
1425 node->tensor_parm.tensor_shape.shape_c = infer_shape;
1426 }
1427 }
1428
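// Traverse the rec graph from the last node to the first and apply the BatchMatMul shape inference above.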
1429 void TmpInferForDynamicShapeInSAPP(const std::shared_ptr<Graph> &graph) {
1430 for (size_t index = graph->nodes.size(); index > 0; index--) {
1431 auto &node = graph->nodes[index - 1];  // take a reference so the inferred shapes are written back into the graph
1432 if (node.apply.op_type == OperatorType::kRecBatchMatMul) {
1433 TMpInferBatchMatMul(graph, &node);
1434 }
1435 }
1436 }
1437
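// Return true if any operator's primitive carries a user-configured IN_STRATEGY attribute.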
1438 bool HasUserConfiguredStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
1439 for (auto op : ops) {
1440 auto prim_anf_node = GetValueNode<PrimitivePtr>(op->cnode()->input(0));
1441 bool has_user_configured_strategy = prim_anf_node->HasAttr(parallel::IN_STRATEGY);
1442 if (has_user_configured_strategy) {
1443 return true;
1444 }
1445 }
1446 return false;
1447 }
1448
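// Strategy search with the recursive-programming (SAPP) algorithm:
// rebuild the cost graph, parse it into a rec graph, eliminate auxiliary nodes, partition the graph across all
// devices, then generate and initialize the per-operator strategies.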
1449 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, size_t rank_id,
1450 const size_t device_num) {
1451 bool dyn_shape_tmp_fix = false;
1452 if (device_num > 0) {
1453 dyn_shape_tmp_fix = true;
1454 }
1455
1456 ReInitCostGraph(all_nodes, root, dyn_shape_tmp_fix);
1457 auto ops = entire_costgraph->GetOperators();
1458 if (dyn_shape_tmp_fix && HasUserConfiguredStrategy(ops)) {
1459 MS_LOG(WARNING) << "Now the split strategy will be automatically generated through SAPP, which will overwrite "
1460 "the strategy that has been manually configured by the user.";
1461 }
1462
1463 std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1464 // Needed by rec_parser 2
1465 auto param_users_uniqueid_list = entire_costgraph->get_param_users_uniqueid_list();
1466 auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
1467 for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
1468 input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
1469 }
1470 std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
1471
1472 std::vector<std::vector<size_t>> param_users_ops_index =
1473 GetIndexOfOpsSharingInputTensor(param_users_uniqueid_list, input_tensor_names);
1474 std::shared_ptr<std::vector<std::vector<size_t>>> eli_list = std::make_shared<std::vector<std::vector<size_t>>>();
1475 std::shared_ptr<std::vector<size_t>> index_list = std::make_shared<std::vector<size_t>>();
1476 graph = EliminateGraph(graph, eli_list, index_list, dyn_shape_tmp_fix);
1477 graph->dyn_shape_tmp_fix = dyn_shape_tmp_fix;
1478
1479 if (graph->dyn_shape_tmp_fix) {
1480 TmpInferForDynamicShapeInSAPP(graph);
1481 }
1482
1483 if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1484 CalculateMicroBatchSize(graph, root);
1485 }
1486
1487 size_t num_device = g_device_manager->DeviceNum();
1488 const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
1489 // Specify whether the process is training or inference. For training, if optimizer parallelism is enabled, at
1490 // least one cut on the DP dimension is required.
1491 bool isTraining = IsTraining(root->manager());
1492 if (PartitionForAllDevices(num_device, device_memory, graph, isTraining, root) == SUCCESS) {
1493 MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
1494 } else {
1495 MS_LOG(ERROR) << "PartitionForAllDevices failed.";
1496 return FAILED;
1497 }
1498
1499 // Needed when changing stage number
1500 if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1501 if (!graph->dyn_shape_tmp_fix) {
1502 if (ParallelInit() != SUCCESS) {
1503 MS_LOG(EXCEPTION) << "Parallel init failed after Rec search";
1504 }
1505 } else {
1506 if (ParallelInit(rank_id, device_num) != SUCCESS) {
1507 MS_LOG(EXCEPTION) << "Parallel init failed";
1508 }
1509 }
1510 if (parallel::ParallelContext::GetInstance()->auto_pipeline()) {
1511 ReInitCostGraph(all_nodes, root, graph->dyn_shape_tmp_fix);
1512 ops = entire_costgraph->GetOperators();
1513 }
1514 }
1515
1516 GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, isTraining, param_users_ops_index, root);
1517
1518 // print the selected strategy
1519 for (auto &op : entire_costgraph->GetOperators()) {
1520 StrategyPtr s_strategy = op->selected_strategy();
1521 if (s_strategy != nullptr) {
1522 MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1523 }
1524 }
1525
1526 if (graph->dyn_shape_tmp_fix) {
1527 (void)WriteStrategiesBackToAnfGraph(ops);
1528 (void)IgnoreOperatorsInCostGraph();
1529 ops_in_a_loop_.clear();
1530 configured_stra_ops_.clear();
1531 ignore_candidate_.clear();
1532 return SUCCESS;
1533 }
1534
1535 if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1536 MS_LOG(INFO) << "Init selected strategy succeeded.";
1537 } else {
1538 MS_LOG(ERROR) << "Init selected strategy failed.";
1539 return FAILED;
1540 }
1541
1542 (void)IgnoreOperatorsInCostGraph();
1543 ops_in_a_loop_.clear();
1544 configured_stra_ops_.clear();
1545 ignore_candidate_.clear();
1546
1547 return SUCCESS;
1548 }
1549
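// Rebuild the cost graph and assign each operator the strategy stored in the auto-op-strategy checkpoint file;
// fail if the checkpoint cannot be loaded or an operator has no recorded strategy.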
1550 Status LoadStrategyFromFile(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
1551 InitCostGraph();
1552 bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1553 (ParallelContext::GetInstance()->sharding_propagation());
1554 if (CostModelContext::GetInstance()->is_multi_subgraphs() || use_sp) {
1555 if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1556 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1557 << entire_costgraph->GetOperators().size() << " operators.";
1558 } else {
1559 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1560 }
1561 } else {
1562 if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1563 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1564 << entire_costgraph->GetOperators().size() << " operators.";
1565 } else {
1566 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1567 }
1568 }
1569
1570 ReshapeCostCompute(all_nodes);
1571 // load strategy map from json
1572 StrategyMap stra_map;
1573 StrategyPtr strategy = nullptr;
1574 if ((StrategyCheckpoint::GetInstance().LoadAutoOpStrategy(&stra_map) != SUCCESS)) {
1575 return FAILED;
1576 }
1577 for (auto &op : entire_costgraph->GetOperators()) {
1578 std::string strategy_key_name = op->cnodes()[0]->fullname_with_scope();
1579 bool load_strategy_from_json = stra_map.find(strategy_key_name) != stra_map.end();
1580 if (!load_strategy_from_json) {
1581 MS_LOG(INFO) << "not found strategy for " << strategy_key_name;
1582 return FAILED;
1583 }
1584 strategy = stra_map[strategy_key_name];
1585 op->SetSelectedStrategy(strategy, 0);
1586 }
1587 if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1588 MS_LOG(INFO) << "Init selected strategy succeeded.";
1589 } else {
1590 MS_LOG(INFO) << "Init selected strategy failed.";
1591 return FAILED;
1592 }
1593
1594 // print the selected strategy
1595 for (auto &op : entire_costgraph->GetOperators()) {
1596 StrategyPtr s_strategy = op->selected_strategy();
1597 if (s_strategy != nullptr) {
1598 MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1599 }
1600 }
1601 (void)IgnoreOperatorsInCostGraph();
1602 ops_in_a_loop_.clear();
1603 configured_stra_ops_.clear();
1604 ignore_candidate_.clear();
1605
1606 MS_LOG(INFO) << "End load strategies from file";
1607 return SUCCESS;
1608 }
1609
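// Persist the selected strategy of every operator, keyed by the full scope name of its first CNode, to the
// strategy checkpoint file.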
1610 void SaveStrategyToFile() {
1611 StrategyMap stra_map;
1612 TensorInfoMap tensor_info_map;
1613 ManualShapeMap manual_shape_map;
1614
1615 for (auto &op : entire_costgraph->GetOperators()) {
1616 StrategyPtr s_strategy = op->selected_strategy();
1617 std::string strategy_key_name = op->cnodes()[0]->fullname_with_scope();
1618 stra_map[strategy_key_name] = s_strategy;
1619 }
1620 if (StrategyCheckpoint::GetInstance().SaveAutoOpStrategy(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
1621 MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
1622 }
1623 MS_LOG(INFO) << "Success save strategies to file.";
1624 }
1625 } // namespace parallel
1626 } // namespace mindspore
1627