1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/step_auto_parallel.h"
18 
19 #include <algorithm>
20 #include <cinttypes>
21 #include <ctime>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include "frontend/optimizer/opt.h"
30 #include "frontend/optimizer/optimizer.h"
31 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
32 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
33 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
34 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
35 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
36 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
37 #include "frontend/parallel/graph_util/graph_info.h"
38 #include "frontend/parallel/graph_util/node_info.h"
39 #include "frontend/parallel/ops_info/reshape_info.h"
40 #include "frontend/parallel/ops_info/tmp_identity_info.h"
41 #include "frontend/parallel/parameter_manager.h"
42 #include "frontend/parallel/step_parallel.h"
43 #include "frontend/parallel/step_parallel_utils.h"
44 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
45 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
46 #include "include/common/utils/parallel_context.h"
47 #include "ir/anf.h"
48 #include "ir/param_info.h"
49 #include "ir/tensor.h"
50 #include "ops/array_ops.h"
51 #include "ops/framework_ops.h"
52 #include "ops/math_ops.h"
53 #include "ops/other_ops.h"
54 #include "ops/sequence_ops.h"
55 #include "pipeline/jit/ps/pipeline_split.h"
56 #include "utils/hash_map.h"
57 #include "utils/hash_set.h"
58 #include "utils/ms_context.h"
59 #if defined(__linux__) && defined(WITH_BACKEND)
60 #include "include/backend/distributed/ps/util.h"
61 #endif
62 
63 namespace mindspore {
64 namespace parallel {
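// The strategy search entry below dispatches on "search_mode". For context (not part of this file),
// the mode is usually chosen from Python, e.g.
//   mindspore.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")
// with "dynamic_programming" and "recursive_programming" being the other accepted values.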
65 void SearchParallelStrategy(const std::string &strategy_search_mode, const FuncGraphPtr &root,
66                             const std::vector<AnfNodePtr> &all_nodes) {
67   if (StrategyCheckpoint::GetInstance().LoadAutoOpStrategyOn()) {
68     if (LoadStrategyFromFile(root, all_nodes) == SUCCESS) {
69       MS_LOG(INFO) << "Load strategies success, jump searching strategy.";
70       return;
71     }
72   }
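  // Dispatch on the configured search mode: "dynamic_programming" and "sharding_propagation" share
  // the cost-graph based search (ParallelStrategySearch), while "recursive_programming" uses the
  // recursive partition algorithm (ParallelStrategyRecSearch).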
73   if ((strategy_search_mode == kDynamicProgramming) || (strategy_search_mode == kShardingPropagation)) {
74     if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
75       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
76                         << " searching mode";
77     }
78   } else if (strategy_search_mode == kRecursiveProgramming) {
79     if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
80       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
81     }
82   } else {
83     MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
84   }
85   if (StrategyCheckpoint::GetInstance().SaveAutoOpStrategyOn()) {
86     SaveStrategyToFile();
87   }
88 }
89 
90 bool HasCellShard(const FuncGraphPtr &func_graph) {
91   AnfNodePtr ret = func_graph->get_return();
92   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
93   for (auto &node : all_nodes) {
94     if (IsPrimitiveCNode(node, prim::kPrimShard) || IsPrimitiveCNode(node, prim::kPrimReshard)) {
95       return true;
96     }
97   }
98   return false;
99 }
100 
101 bool IsSkipAutoParallel(const FuncGraphPtr &root, const std::string &strategy_search_mode, const bool is_pre_action) {
102   root->set_flag(kHasShard, HasCellShard(root));
103   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
104   if (root->has_flag(kSkipAutoParallelCompile) || parallel_mode != kAutoParallel ||
105       root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY) || HasNestedMetaFg(root)) {
106     return true;
107   }
108 
109   // For parallelism with shard, skip PreAutoParallel.
110   // The Shard primitive will be deleted once shard is set; see pass.cc.
111   if (root->has_flag(kHasShard)) {
112     return true;
113   }
114   if (root->has_flag(kSharded)) {
115     return false;
116   }
117   if (parallel::IsPynativeParallel() && !root->has_flag(kHasShard)) {
118     return true;
119   }
120 
121   if ((is_pre_action && strategy_search_mode == kDynamicProgramming) ||
122       (!is_pre_action && strategy_search_mode != kDynamicProgramming)) {
123     return true;
124   }
125   return false;
126 }
127 
128 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
129   // Mode 'dynamic programming' runs after pipeline_split; the other modes do not.
130   MS_EXCEPTION_IF_NULL(root);
131   bool is_pre_action = !root->has_flag(AUTO_PARALLEL_FINISH_PRE_ACTION);
132   bool changes;
133   if (is_pre_action) {
134     root->set_flag(AUTO_PARALLEL_FINISH_PRE_ACTION, true);
135     auto manager = root->manager();
136     const auto &graphs = manager->func_graphs();
137     bool is_training = std::any_of(graphs.cbegin(), graphs.cend(),
138                                    [](auto cur_graph) -> bool { return cur_graph->has_flag(kTraining); });
139     if (is_training) {
140       root->set_flag(kTraining, true);
141     }
142     changes = true;
143   } else {
144     changes = false;
145   }
146 
147 #if defined(__linux__) && defined(WITH_BACKEND)
148   if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
149     return changes;
150   }
151 #endif
152   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
153   // controls whether the model_parallel mode is used
154   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
155   bool is_skip = IsSkipAutoParallel(root, strategy_search_mode, is_pre_action);
156   if (is_skip && !ParallelContext::GetInstance()->direct_split()) {
157     return changes;
158   }
159   MS_LOG(INFO) << "search_mode: " << strategy_search_mode;
160 
161   // tag dynamic shape graph
162   parallel::TagDynamicShapeFuncGraph(root);
163 
164   MSLogTime msTime;
165   msTime.Start();
166 #ifdef ENABLE_DUMP_IR
167   auto context = MsContext::GetInstance();
168   MS_EXCEPTION_IF_NULL(context);
169   if (context->CanDump(kIntroductory)) {
170     DumpGraph(root, std::string(STEP_AUTO_PARALLEL_BEGIN));
171   }
172 #endif
173   MS_LOG(INFO) << "Now entering step auto parallel";
174   TOTAL_OPS = 0;
175   AnfNodePtr ret = root->get_return();
176   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
177 
178   // insert VirtualDataset if it does not exist
179   if (ParallelInit() != SUCCESS) {
180     MS_LOG(EXCEPTION) << "Parallel init failed";
181   }
182   if (!mindspore::pipeline::HasVirtualDataset(all_nodes)) {
183     mindspore::pipeline::InsertVirtualDataset(root, all_nodes);
184   }
185   // redo the deep-scoped search to connect the VirtualDataset into the graph
186   all_nodes = DeepScopedGraphSearch(ret);
187 
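  // A power-of-two check via the usual bit trick: (n & (n - 1)) == 0 only when n has a single bit
  // set, e.g. 8 & 7 == 0, whereas 6 & 5 == 4 != 0.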
188   if (strategy_search_mode == kRecursiveProgramming &&
189       ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0)) {
190     MS_LOG(EXCEPTION)
191       << "The recursive auto parallel strategy searching mode requires the device num be the power of 2.";
192   }
193   // mark the forward cnodes; parallel only cares about these nodes
194   MarkForwardCNode(root);
195 
196   ExceptionIfHasCommunicationOp(all_nodes);
197 
198   if (IsInsertVirtualOutput(root)) {
199     InsertVirtualOutput(root, all_nodes);
200     AnfNodePtr ret_after = root->get_return();
201     MS_EXCEPTION_IF_NULL(ret_after);
202     all_nodes = DeepScopedGraphSearch(ret_after);
203   }
204 
205   // search parallelization strategy
206   SearchParallelStrategy(strategy_search_mode, root, all_nodes);
207   msTime.End();
208   uint64_t time = msTime.GetRunTimeUS();
209 
210   MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
211 
212   root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
213   return changes;
214 }
215 
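// Elementwise operators can simply follow the sharding strategy of their producer; this set is
// consulted in CreateEdgeBetweenTwoOps when deciding whether an edge forces strategy-following
// (see the 'elementwise_stra_follow' cost-model flag).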
216 bool IsElementWiseOperator(const std::string &op_name) {
217   // clang-format off
218   static const std::set<std::string> elementwise_op = {ACTIVATION, GELU,         TANH,
219                                                        SOFTMAX,    LOG_SOFTMAX,  RELU,
220                                                        SQRT,       CAST,         POW,
221                                                        EXP,        LOG,          COS,
222                                                        ACOS,       LOGICALNOT,   NEG,
223                                                        SQUARE,     SIGMOID,      ABS,
224                                                        ACOSH,      ASIN,         ASINH,
225                                                        ATAN,       ATANH,        CEIL,
226                                                        COSH,       EXPM1,        LOG1P,
227                                                        SIN,        SINH,         TAN,
228                                                        RSQRT,      RECIPROCAL,   INV,
229                                                        ROUND,      FLOOR,        SIGN,
230                                                        ERF,        ERFC,         ZEROSLIKE,
231                                                        ONESLIKE,   BESSELI0E,    BESSELI0,
232                                                        BESSELI1,   BESSELJ0,     BESSELJ0,
233                                                        ASSIGN,     ASSIGN_ADD,   ATAN2,
234                                                        DIVNONAN,   LOGICALAND,   ELU,
235                                                        LOGICALOR,  RELU6,        SOFTPLUS,
236                                                        SOFTSIGN,   LESS,         LESSEQUAL,
237                                                        BESSELI1E,  GREATEREQUAL, APPROXIMATEEQUAL,
238                                                        MOD,        REVERSEV2,    REPEAT_ELEMENTS,
239                                                        TRUNC,      LGAMMA,       CHOLESKY};
240   // clang-format on
241   auto iter = elementwise_op.find(op_name);
242   return (iter != elementwise_op.cend());
243 }
244 
245 // Records the operators appearing in a for-loop.
246 // Currently, we assume that the operators in different for-loops are identical, and that their
247 // traversal orderings are also identical.
248 // Therefore, we create OperatorInfo objects for the operators in one loop (say, loop-3), and reuse
249 // them in the rest of the loops (loop-2, loop-1 and loop-0).
250 std::set<std::string> ops_in_a_loop_;
251 // Whether two operators are in different loops; if it is true, then return true.
252 // If at least one of the two operators is not in the loop, then return false.
253 // If the two operators are in the same loop, then return false.
254 bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
255   auto a_op_info = a_cnode->user_data<OperatorInfo>();
256   MS_EXCEPTION_IF_NULL(a_op_info);
257   auto b_op_info = b_cnode->user_data<OperatorInfo>();
258   MS_EXCEPTION_IF_NULL(b_op_info);
259   if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
260       (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
261     return false;
262   }
263   size_t a_loop_index = 0, b_loop_index = 0;
264   const auto &a_fullname = a_cnode->fullname_with_scope();
265   if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
266     MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
267   }
268   const auto &b_fullname = b_cnode->fullname_with_scope();
269   if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
270     MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
271   }
272   if (a_loop_index == b_loop_index) {
273     return false;
274   }
275   return true;
276 }
277 
278 // 'configured_stra_ops_' includes all operators that have configured sharding strategies.
279 std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_;
280 std::set<OperatorInfoPtr> ignore_candidate_;
281 void InitCostGraph() {
282   if (entire_costgraph == nullptr) {
283     entire_costgraph = std::make_shared<CostGraph>();
284   }
285   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
286   CostModelContext::GetInstance()->PrintCostModel();
287   entire_costgraph->Init();
288   configured_stra_ops_.clear();
289   ignore_candidate_.clear();
290 }
291 
292 void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
293                            mindspore::HashMap<std::string, ValuePtr> attrs, bool, StrategyMap *stra_map,
294                            const std::string &strategy_key_name) {
295   // In this case, the configured strategy should be extracted to help set the cost
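  // (For reference: such a strategy typically originates from Python via Primitive.shard(), e.g.
  //  ops.MatMul().shard(((2, 1), (1, 4))), and reaches this pass through attrs[IN_STRATEGY].)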
296   StrategyPtr strategyPtr;
297   if (StrategyFound(attrs)) {
298     strategyPtr = parallel::ExtractStrategy(attrs[IN_STRATEGY]);
299   } else {
300     strategyPtr = (*stra_map)[strategy_key_name];
301   }
302 
303   if (strategyPtr == nullptr) {
304     return;
305   }
306 
307   if (prim->name() == RESHAPE) {
308     MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
309   }
310 
311   // Set cost for this configured strategy
312   if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
313     MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
314   }
315 
316   const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
317   if (fully_use_devices) {
318     // If configured to fully use devices, then check the user-specified strategy
319     int64_t used_devices = operator_info->used_devices();
320     MS_EXCEPTION_IF_NULL(g_device_manager);
321     auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
322 
323     // 'used_devices == -1' means that 'used_devices_' is not set
324     // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
325     if (used_devices == -1 || (used_devices != 1 && LongToSize(used_devices) != total_device_num)) {
326       MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
327                         << "but the specified strategy uses device: " << used_devices
328                         << ", total devices: " << total_device_num
329                         << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
330                            "in package 'mindspore.parallel'.";
331     }
332   }
333   (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
334 }
335 
336 void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
337   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
338   if (approximation) {
339     operator_info->ApproximateStrategies();
340     MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
341   }
342 }
343 
344 void AddOperatorToIgnoreCandidates(const PrimitivePtr &prim, const OperatorInfoPtr &operator_info) {
345   if (prim->name() == CAST) {
346     // add CAST into ignore_candidate
347     (void)ignore_candidate_.insert(operator_info);
348   }
349 }
350 
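// Generates candidate strategies for a single operator: if its attrs request
// STRATEGY_GEN_MODE == "data_parallel", only the batch-parallel strategy is produced (and recorded
// in IN_STRATEGY for the recursive algorithm); otherwise the full candidate set is enumerated.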
351 bool GenerateStrategiesByOperatorInfoPtr(const OperatorInfoPtr &operator_info) {
352   Status retGenStra;
353   auto attrs = operator_info->attrs();
354   if (AttrFound(attrs, STRATEGY_GEN_MODE) && GetValue<std::string>(attrs[STRATEGY_GEN_MODE]) == kDataParallel) {
355     MS_LOG(INFO) << "generating batch parallel strategy...";
356     auto prim = std::make_shared<Primitive>(operator_info->name());
357     StrategyPtr strategyPtr = parallel::GenerateBatchParallelStrategy(operator_info, prim);
358     retGenStra = operator_info->SetCostUnderStrategy(strategyPtr);
359     attrs = prim->attrs();
360     operator_info->addAttr(IN_STRATEGY, attrs[GEN_STRATEGY]);  // for d-rec
361   } else {
362     MS_LOG(INFO) << "auto-searching strategy...";
363     retGenStra = operator_info->GenerateStrategies(0);
364   }
365   if (retGenStra != SUCCESS) {
366     MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
367     return false;
368   }
369   return true;
370 }
371 
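// Creates and initializes the OperatorInfo for one CNode: parameter flags, input/output type
// lengths, any user-configured or checkpointed strategy, and (outside recursive-programming mode)
// the generated candidate strategies, optionally approximated.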
372 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
373                                       StrategyMap *stra_map) {
374   MS_EXCEPTION_IF_NULL(prim);
375   MS_EXCEPTION_IF_NULL(cnode);
376   // Create an OperatorInfo instance
377   OperatorInfoPtr operator_info = CreateOperatorInfo(cnode);
378   MS_EXCEPTION_IF_NULL(operator_info);
379   // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
380   std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
381   if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
382     MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
383     return nullptr;
384   }
385   // Set the data type for inputs and outputs of this OperatorInfo
386   auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
387   auto outputs_type = ExtractOutputTypeByNode(cnode);
388   if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
389     std::string param_name = ExtractInputParameterNameByNode(cnode);
390     if (!param_name.empty()) {
391       operator_info->set_involved_param_name(param_name);
392     }
393   }
394   std::vector<size_t> outputs_type_length;
395   outputs_type_length.reserve(outputs_type.size());
396   (void)std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
397                        GetLengthOfDataType);
398   if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
399     MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
400     return nullptr;
401   }
402   if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
403     MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
404     return nullptr;
405   }
406 
407   operator_info->set_auto_parallel(true);
408 
409   AddOperatorToIgnoreCandidates(prim, operator_info);
410   // key of strategy map
411   std::string strategy_key_name;
412   auto param_names = NodeParameterName(cnode, -1, 0);
413   if (!param_names.empty()) {
414     strategy_key_name = prim->name() + "_" + param_names[0].first;
415   }
416   bool load_strategy_from_ckpt =
417     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
418   // If no strategy has been configured for this operator, then candidate strategies are generated for
419   // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
420   // If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
421   auto attrs = prim->attrs();
422   if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
423     if ((StrategyFound(attrs) && prim->name() != CAST) || load_strategy_from_ckpt) {
424       SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
425       return operator_info;
426     }
427   }
428   // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used in
429   // preparation for the BatchParallelInfo operator
430   operator_info->ComputeBatchSplitFlagList();
431   if ((ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming)) {
432     (void)GenerateStrategiesByOperatorInfoPtr(operator_info);
433   }
434 
435   bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
436                              (ParallelContext::GetInstance()->sharding_propagation())) &&
437                             (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
438   if (use_sp_and_dataset) {
439     const auto &swc_vec = operator_info->GetStrategyCost();
440     if (swc_vec.empty()) {
441       MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
442     }
443     MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
444     (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
445   }
446   // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
447   ApplyApproximationForNode(operator_info);
448   return operator_info;
449 }
450 
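// Sanity check used when reusing a cached OperatorInfo: returns true ("wrong match") when the
// OperatorInfo name matches neither "<prim_name>Info", VirtualDatasetInfo nor BatchParallel;
// for GatherV2 the "GatherV2PInfo" form is also accepted.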
451 bool IsFindWrong(const OperatorInfoPtr &current_op_ptr, const std::string &prim_name) {
452   bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
453                        (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
454                        (current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
455   if (prim_name == GATHERV2) {
456     is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
457   }
458   return is_find_wrong;
459 }
460 
461 void AddUsersUniqueIdWhenSharingParameter(
462   const std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>> &parameter_users_info) {
463   auto users_set = parameter_users_info.second.second;
464   if (users_set.size() > 1) {
465     MS_LOG(INFO) << "Parameter " << parameter_users_info.first << " has " << users_set.size() << " users.";
466     std::vector<std::string> param_users_uniqueid;
467     for (const auto &user : users_set) {
468       MS_LOG(INFO) << "with ID: " << user.first->UniqueId() << " and name: " << user.first->UniqueName();
469       param_users_uniqueid.push_back(user.first->UniqueId());
470     }
471     entire_costgraph->add_param_users_uniqueid(param_users_uniqueid);
472   }
473 }
474 
475 void AddParamUsersForRec(const std::vector<AnfNodePtr> &all_nodes) {
476   for (auto &node : all_nodes) {
477     if (node->isa<Parameter>()) {
478       ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
479       AddUsersUniqueIdWhenSharingParameter(parameter_users_info);
480     }
481   }
482 }
483 
484 // Using CNode's UniqueIds to construct nodes
485 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
486   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
487   // The map from CNode's UniqueId to its operatorInfo
488   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
489   // The operator_infos in a loop
490   std::vector<OperatorInfoPtr> operators_in_forloop;
491   // Key: i-th loop; Value: index of 'operators_in_forloop'
492   std::map<size_t, size_t> loop_to_ops;
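  // (Illustration: the first loop visited creates OperatorInfos and pushes them into
  //  'operators_in_forloop'; later loops reuse those objects in traversal order, with
  //  loop_to_ops[loop_index] advancing through the vector.)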
493   // extract strategy from checkpoint for multi-train
494   StrategyMap stra_map;
495   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
496     if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
497       MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
498     }
499   }
500 
501   for (auto &node : all_nodes) {
502     // NOTE: we only care about splittable Primitive operators
503     auto cnode = node->cast<CNodePtr>();
504     bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
505     if (bool_result) {
506       continue;
507     }
508     auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
509     if (!IsAutoParallelCareNode(cnode)) {
510       // Needed by rec_parser
511       if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
512         auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
513         if (prev_cnode != nullptr) {
514           entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
515         }
516       }
517       continue;
518     }
519     auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
520     MS_EXCEPTION_IF_NULL(prim);
521 
522     auto search_cnode = from_cnode_to_info.find(cnode->UniqueId() + prim->name());
523     if (search_cnode == from_cnode_to_info.cend()) {
524       size_t loop_index = 0;
525       bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
526       const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
527       if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
528         const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
529         if (IsFindWrong(current_op_ptr, prim->name())) {
530           MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
531                             << " does not match the Prim: " << prim->name()
532                             << ". The fullname_with_scope: " << cnode->fullname_with_scope();
533         }
534         loop_to_ops[loop_index]++;
535         cnode->set_user_data<OperatorInfo>(current_op_ptr);
536         MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
537                      << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
538                      << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
539                      << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
540         (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId() + prim->name(), current_op_ptr));
541         continue;
542       }
543       bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
544       auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
545       if (operator_info == nullptr) {
546         return FAILED;
547       }
548       if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
549         operator_info->set_type(prim->name());
550         operator_info->set_last_node_flag(is_last_nodes);
551         std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode, all_nodes);
552         entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
553       }
554 
555       entire_costgraph->AddOperator(operator_info);
556       cnode->set_user_data<OperatorInfo>(operator_info);
557       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
558                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
559                    << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
560                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
561       (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId() + prim->name(), operator_info));
562       if (single_loop && is_in_loop) {
563         operators_in_forloop.push_back(operator_info);
564         (void)ops_in_a_loop_.insert(operator_info->name());
565         loop_to_ops[loop_index]++;
566       }
567     } else {
568       // Two CNODEs' UniqueIds should not be equal
569       MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
570                         << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
571                         << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
572     }
573   }
574 
575   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
576   // Needed by rec_parser 2
577   AddParamUsersForRec(all_nodes);
578 
579   return SUCCESS;
580 }
581 
582 void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) {
583   if (current_op_ptr == nullptr) {
584     MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
585   } else {
586     if (IsFindWrong(current_op_ptr, prim->name())) {
587       MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
588                         << " does not match the Prim: " << prim->name();
589     }
590 
591     // Needed by rec_parser
592     ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
593 
594     cnode->set_user_data<OperatorInfo>(current_op_ptr);
595     current_op_ptr->set_cnode(cnode);
596     MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
597                  << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
598                  << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
599                  << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
600   }
601 }
602 
603 // Using CNode's UniqueIdThroughCopys to construct nodes
604 Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
605   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
606   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
607   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
608   // The operator_infos in a loop
609   std::vector<OperatorInfoPtr> operators_in_forloop;
610   // Key: i-th loop; Value: index of 'operators_in_forloop'
611   std::map<size_t, size_t> loop_to_ops;
612   // extract strategy from checkpoint for multi-train
613   StrategyMap stra_map;
614   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
615       StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
616     MS_LOG(WARNING) << "Load strategy checkpoint failed";
617     return FAILED;
618   }
619   for (auto &node : all_nodes) {
620     // NOTE: we only care about splittable Primitive operators
621     auto cnode = node->cast<CNodePtr>();
622     if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
623       continue;
624     }
625     auto prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
626     if (!IsAutoParallelCareNode(cnode)) {
627       // Needed by rec_parser
628       if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
629         auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
630         if (prev_cnode != nullptr) {
631           entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
632         }
633       }
634       continue;
635     }
636     auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
637 
638     // Find the operatorInfo if it exists
639     auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy() + prim->name());
640     if (search_cnode == from_cnode_to_info.cend()) {
641       size_t loop_index = 0;
642       bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
643       const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
644       bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
645       if (is_op_created) {
646         const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
647         if (IsFindWrong(current_op_ptr, prim->name())) {
648           MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
649                             << " does not match the Prim: " << prim->name()
650                             << ". The fullname_with_scope: " << cnode->fullname_with_scope();
651         }
652         loop_to_ops[loop_index]++;
653         cnode->set_user_data<OperatorInfo>(current_op_ptr);
654         MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
655                      << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
656                      << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
657                      << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
658         (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy() + prim->name(), current_op_ptr));
659         continue;
660       }
661       // In this case, the corresponding OperatorInfo has not been created yet, so create a new one.
662       bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
663       auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
664       MS_EXCEPTION_IF_NULL(operator_info);
665 
666       if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
667         operator_info->set_type(prim->name());
668         operator_info->set_last_node_flag(is_last_nodes);
669         std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode, all_nodes);
670         entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
671       }
672 
673       entire_costgraph->AddOperator(operator_info);
674       cnode->set_user_data<OperatorInfo>(operator_info);
675       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
676                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
677                    << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
678                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
679       (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy() + prim->name(), operator_info));
680       if (single_loop && is_in_loop) {
681         operators_in_forloop.push_back(operator_info);
682         (void)ops_in_a_loop_.insert(operator_info->name());
683         loop_to_ops[loop_index]++;
684       }
685     } else {
686       SetOperatorToCNode(search_cnode->second, prim, cnode);
687     }
688   }
689 
690   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
691   // Needed by rec_parser 2
692   AddParamUsersForRec(all_nodes);
693 
694   return SUCCESS;
695 }
696 
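// For sharding propagation: when a Cast feeds an operator whose strategy is already configured,
// derive the Cast's strategy backwards from that consumer (minimal-communication choice) so that
// propagation does not stall at the Cast. For MatMul's second input, the "repeated number at the
// right of the device matrix" convention is disabled first and candidates are regenerated.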
697 void PreProcessPreCastForSP(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
698                             const CNodePtr &cnode, const EdgePtr edge_ptr, size_t input_index) {
699   if (IsPrimitiveCNode(cnode, prim::kPrimMatMul) && input_index == INDEX_TWO) {
700     prev_op_info->set_repeated_num_in_dev_matrix_right(false);
701     prev_op_info->ClearStrategyCost();
702     (void)prev_op_info->GenerateStrategies(0);
703   }
704   if ((configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
705     const auto next_op_stra = configured_stra_ops_[node_op_info];
706     if (edge_ptr->InitEdgeCost() != SUCCESS) {
707       MS_LOG(EXCEPTION) << "Edge cost initialization failed";
708     }
709     const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
710     if (cast_stra == nullptr) {
711       MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
712     }
713     prev_op_info->ClearStrategyCost();
714     if (prev_op_info->SetCostUnderStrategy(cast_stra) != SUCCESS) {
715       MS_LOG(EXCEPTION) << "Failure: operator " << prev_op_info->name() << " SetCostUnderStrategy failed";
716     }
717     if (edge_ptr->InitEdgeCost() != SUCCESS) {
718       MS_LOG(EXCEPTION) << "Edge cost re-initialization failed.";
719     }
720     MS_LOG(INFO) << "Set strategy for: " << prev_op_info->name() << " under the strategy of: " << node_op_info->name();
721     (void)configured_stra_ops_.emplace(prev_op_info, cast_stra);
722   }
723 }
724 
725 void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
726                              const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim,
727                              const PrimitivePtr &prev_prim, size_t output_index, size_t input_index,
728                              size_t *edge_count) {
729   std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
730   // If the edge between these two operators has already been added, then it will not be added again.
731   if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) {
732     return;
733   }
734   EdgePtr edge_ptr;
735   MS_LOG(INFO) << "Creating edge: " << edge_name;
736   if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
737     MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
738                  << ", cnode_fullname: " << cnode->fullname_with_scope();
739     MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge.";
740     return;
741   }
742   const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
743   bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
744                          (stra_follow && IsElementWiseOperator(prev_prim->name()));
745   if (follow_strategy) {
746     // Redistribution is not allowed on the edge.
747     // Elementwise operators have the same strategy as their previous operators.
748     edge_ptr =
749       std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true);
750   } else {
751     edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false);
752   }
753   bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
754                 (ParallelContext::GetInstance()->sharding_propagation());
755   // Init costs for this edge
756   if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
757     if (!use_sp && edge_ptr->InitEdgeCost() != SUCCESS) {
758       MS_LOG(EXCEPTION) << "Edge cost initialization failed";
759     }
760   }
761   node_op_info->AddPrevEdge(edge_ptr);
762   prev_op_info->AddSuccEdge(edge_ptr);
763   entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
764   if (use_sp && prev_prim->name() == CAST) {
765     PreProcessPreCastForSP(prev_op_info, node_op_info, cnode, edge_ptr, input_index);
766   }
767   MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name();
768   (*edge_count)++;
769 }
770 
771 void ApplyApproximationForGraphs() {
772   // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
773   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
774   if (approximation) {
775     entire_costgraph->CheckApproximateCostGraphEdges();
776   }
777 }
778 
779 static void CreateEdgeAccrossMakeList(const CNodePtr &cnode, const PrimitivePtr &prim,
780                                       const OperatorInfoPtr &node_op_info, CNodePtr *prev_cnode,
781                                       ValueNodePtr *prev_prim_anf_node, PrimitivePtr *prev_prim, size_t *edge_count) {
782   MS_LOG(INFO) << "Creating edges across the 'make_list' operator.";
783   const auto &sub_inputs = (*prev_cnode)->inputs();
784   for (size_t j = 1; j < sub_inputs.size(); ++j) {
785     *prev_cnode = sub_inputs[j]->cast<CNodePtr>();
786     bool bool_result_list = (*prev_cnode == nullptr) || !IsValueNode<Primitive>((*prev_cnode)->input(0)) ||
787                             !IsAutoParallelCareNode(*prev_cnode);
788     if (bool_result_list) {
789       continue;
790     }
791     *prev_prim_anf_node = (*prev_cnode)->input(0)->cast<ValueNodePtr>();
792     *prev_prim = (*prev_prim_anf_node)->value()->cast<PrimitivePtr>();
793     auto prev_op_info = (*prev_cnode)->user_data<OperatorInfo>();
794     CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, *prev_cnode, prim, *prev_prim, 0, j, edge_count);
795   }
796 }
797 
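// Builds cost-graph edges from one CNode to the producer of each of its inputs, tracing through
// formal Parameters, called FuncGraph outputs, TupleGetItem, MakeTuple/MakeList, Depend and Load
// wrappers until an auto-parallel-care operator (or a cross-graph boundary) is reached.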
798 static void ConstructCNodeCostGraphEdges(const mindspore::CNodePtr &cnode, const std::vector<AnfNodePtr> &all_nodes) {
799   auto &inputs = cnode->inputs();
800   ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
801   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
802   size_t edge_count = 0;
803   auto node_op_info = cnode->user_data<OperatorInfo>();
804 
805   for (size_t i = 1; i < inputs.size(); ++i) {
806     AnfNodePtr prev_node = inputs[i];
807     if (inputs[i]->isa<Parameter>()) {
808       prev_node = FindRealInputByFormalParameter(cnode, inputs[i], all_nodes);
809       if (prev_node->UniqueId() == inputs[i]->UniqueId()) {
810         continue;
811       }
812     }
813     auto prev_cnode = prev_node->cast<CNodePtr>();
814     PrimitivePtr prev_prim;
815     ValueNodePtr prev_prim_anf_node;
816     bool is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
817     if (is_cross) {
818       continue;
819     }
820     size_t output_index = 0;
821     bool is_before_tuple_get_item = false;
822 
823     while (IsCarePrevCNode(prev_cnode, prev_prim)) {
824       if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
825         auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
826         auto output = graph->output();
827         MS_EXCEPTION_IF_NULL(output);
828         prev_cnode = output->cast<CNodePtr>();
829         (void)CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
830       } else if (IsAutoParallelCareNode(prev_cnode)) {
831         auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
832         CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
833                                 &edge_count);
834         break;
835       } else if (prev_prim->name() == prim::kPrimTupleGetItem->name()) {
836         // In this case, 'prev_anf_node' is 'tuple_getitem'; the actual precursor node is the node
837         // before this 'tuple_getitem'
838         output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
839         prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
840         is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
841         if (is_cross) {
842           break;
843         }
844         if (!IsAutoParallelCareNode(prev_cnode) && !IsValueNode<FuncGraph>(prev_cnode->input(0))) {
845           MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
846         }
847         is_before_tuple_get_item = true;
848       } else if (prev_prim->name() == kMakeTupleOpName) {
849         if (!is_before_tuple_get_item) {
850           CreateEdgeAccrossMakeList(cnode, prim, node_op_info, &prev_cnode, &prev_prim_anf_node, &prev_prim,
851                                     &edge_count);
852           break;
853         }
854         prev_cnode = prev_cnode->input(output_index + 1)->cast<CNodePtr>();
855         output_index = 0;
856         is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
857         if (is_cross) {
858           break;
859         }
860         is_before_tuple_get_item = false;
861       } else if (prev_prim->name() == kMakeListOpName) {
862         CreateEdgeAccrossMakeList(cnode, prim, node_op_info, &prev_cnode, &prev_prim_anf_node, &prev_prim, &edge_count);
863         break;
864       } else if (prev_prim->name() == kDependOpName || prev_prim->name() == kLoadOpName) {
865         // In this case, 'prev_anf_node' is 'depend'; the actual precursor node is the node before
866         // this 'depend'
867         prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
868         is_cross = CrossInterNode(&prev_cnode, &prev_prim_anf_node, &prev_prim);
869         if (is_cross) {
870           break;
871         }
872         is_before_tuple_get_item = true;
873       }
874     }
875   }
876   MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
877 }
878 
879 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
880   // Step 2
881   MS_LOG(INFO) << "Constructing edges for cost graph begins.";
882   for (auto &node : all_nodes) {
883     auto cnode = node->cast<CNodePtr>();
884     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
885       continue;
886     }
887     if (!IsAutoParallelCareNode(cnode)) {
888       continue;
889     }
890     ConstructCNodeCostGraphEdges(cnode, all_nodes);
891   }
892   ApplyApproximationForGraphs();
893 
894   MS_LOG(INFO) << "Constructing edges for cost graph ends.";
895 }
896 
897 void ApplyApproximationForParaNode(const OperatorInfoPtr &target_op_info) {
898   // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
899   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
900   if (approximation) {
901     target_op_info->ExactStrategiesAndRelatedEdges();
902   }
903 }
904 
905 std::pair<OperatorInfoPtr, bool> CreateIdentityOp(const std::string &parameter_name,
906                                                   const AnfNodePtr &target_parameter) {
907   // Here, it is certain that this Parameter (RefKey) is being used by multiple Operators.
908   OperatorInfoPtr tmp_identity_ptr;
909   bool new_identity = false;
910   auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
911   if (returned_identity != nullptr) {
912     // In this case, the TmpIdentityInfo instance has already been created
913     new_identity = false;
914     tmp_identity_ptr = returned_identity;
915   } else {
916     // In this case, the TmpIdentityInfo instance has NOT been created yet. Thus, a new one is created.
917     new_identity = true;
918     // 1) extract input shape from this Parameter
919     MS_EXCEPTION_IF_NULL(target_parameter);
920     AbstractBasePtr abstract = target_parameter->abstract();
921     if (abstract == nullptr) {
922       MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
923     }
924     auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
925     if (input_shape == nullptr) {
926       MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
927     }
928     Shape shape = input_shape->shape();
929     Shapes inputs_shape = {shape};
930     Shapes outputs_shape = {shape};
931     // 2) init the attr
932     mindspore::HashMap<std::string, ValuePtr> attr = {};
933 
934     // Create the TmpIdentity instance
935     tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
936     tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
937     TOTAL_OPS++;
938     tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
939     // Set the parameter and type lengths for inputs and outputs
940     std::vector<bool> is_parameter;
941     auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
942     MS_EXCEPTION_IF_NULL(casted_target_parameter);
943     is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
944     if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
945       MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
946     }
947     auto node_type = target_parameter->Type();
948     if (node_type->isa<mindspore::TensorType>()) {
949       auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
950       std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
951       if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
952         MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
953       }
954     } else {
955       MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
956     }
957 
958     // Generate strategies for this TmpIdentityInfo instance;
959     if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
960       MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
961     }
962   }
963   return std::make_pair(tmp_identity_ptr, new_identity);
964 }
965 
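// Step 3 of the search (see the ParallelStrategySearch comment below): when one Parameter is shared
// by several operators, create a TmpIdentity pseudo-operator that owns the parameter and connect it
// by an edge to every consumer, so a single consistent layout is chosen for the shared weight.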
966 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
967   // Step 3
968   for (auto &node : all_nodes) {
969     ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode, all_nodes);
970     auto parameter_name = parameter_users_info.first;
971     auto target_parameter = parameter_users_info.second.first;
972     auto target_set = parameter_users_info.second.second;
973     if (target_set.size() <= 1) {
974       continue;
975     }
976 
977     // Rule out the case where a Parameter is used by an Operator, but the Operator appears in multiple CNODEs
978     std::set<std::string> target_without_duplicate;
979     for (auto &target : target_set) {
980       auto target_cnode = target.first->cast<CNodePtr>();
981       // Eliminate the ops without cost.
982       if (IsSomePrimitive(target_cnode, SEND)) {
983         continue;
984       }
985       auto input_index = target.second;
986       (void)target_without_duplicate.insert(std::to_string(input_index) +
987                                             target_cnode->user_data<OperatorInfo>()->name());
988     }
989     if (target_without_duplicate.size() <= 1 || parameter_name.empty()) {
990       continue;
991     }
992 
993     auto pair = CreateIdentityOp(parameter_name, target_parameter);
994     OperatorInfoPtr tmp_identity_ptr = pair.first;
995     bool new_identity = pair.second;
996     // A flag recording whether new edges have been created or not
997     bool add_identity_edge = false;
998 
999     // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
1000     for (auto &target : target_set) {
1001       auto target_cnode = target.first->cast<CNodePtr>();
1002       auto input_index = target.second;
1003       auto target_op_info = target_cnode->user_data<OperatorInfo>();
1004       if (!target_op_info->repeated_num_in_dev_matrix_right() && tmp_identity_ptr->repeated_num_in_dev_matrix_right()) {
1005         tmp_identity_ptr->set_repeated_num_in_dev_matrix_right(false);
1006         tmp_identity_ptr->ClearStrategyCost();
1007         (void)tmp_identity_ptr->GenerateStrategies(0);
1008       }
1009 
1010       std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
1011       // If the edge between these two operators has already been added, then it will not be added again.
1012       if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
1013         continue;
1014       }
1015       std::shared_ptr<Edge> edge_ptr =
1016         std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
1017       ApplyApproximationForParaNode(target_op_info);
1018 
1019       bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1020                     (ParallelContext::GetInstance()->sharding_propagation());
1021       if (!use_sp && edge_ptr->InitEdgeCost() != SUCCESS) {
1022         MS_LOG(EXCEPTION) << "Edge cost initialization failed";
1023       }
1024       target_op_info->AddPrevEdge(edge_ptr);
1025       tmp_identity_ptr->AddSuccEdge(edge_ptr);
1026       entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
1027       MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
1028                    << target_op_info->name();
1029       add_identity_edge = true;
1030     }
1031     if (new_identity && add_identity_edge) {
1032       // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied
1033       entire_costgraph->AddOperator(tmp_identity_ptr);
1034     }
1035   }
1036 }
1037 
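// For each Reshape, collect the strategy costs of its producer and of its consumers and use them
// to generate the Reshape's own (input_layout, output_layout) candidates; in recursive-programming
// mode the neighbouring edges are constructed instead of generating Reshape strategy costs here.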
1038 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
1039   mindspore::HashSet<std::string> op_cache;
1040   for (const auto &node : all_nodes) {
1041     auto cnode = node->cast<CNodePtr>();
1042     if (!FindReshape(cnode, &op_cache)) {
1043       continue;
1044     }
1045     MS_ASSERT(cnode->size() == 3);
1046     // get previous node's strategy_cost_
1047     auto pre_node = cnode->input(1);
1048     if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
1049       pre_node = pre_node->cast<CNodePtr>()->input(1);
1050     }
1051     int64_t out_index = 0;
1052     OperatorInfoPtr pre_operator_info;
1053     std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
1054     auto operator_info = cnode->user_data<OperatorInfo>();
1055     bool is_prev_param = false;
1056     if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &is_prev_param, &out_index, 0)) {
1057       MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
1058     }
1059     // For the recursive-programming (double-recursive) mode, enumerate the strategies of the Reshape and of its preceding operator
1060     if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
1061       ConstructCNodeCostGraphEdges(cnode, all_nodes);
1062     }
1063     if (is_prev_param) {
1064       auto reshape_info1 = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
1065       reshape_info1->SetCostForReshapeWithParameter();
1066       pre_operator_info = reshape_info1;
1067       pre_stra_costs = reshape_info1->strategy_cost();
1068     } else {
1069       pre_stra_costs = pre_operator_info->strategy_cost();
1070     }
1071     // get next node's strategy_cost_
1072     std::vector<std::pair<OperatorInfoPtr, int64_t>> next_ops_index;
1073     bool is_next_reshape = false;
1074     std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index;
1075     (void)FindReshapeNextNodeStraCosts(cnode, &next_ops_index, &is_next_reshape, 0);
1076     if (next_ops_index.empty()) {
1077       MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
1078     }
1079     // set input_layout and output_layout for reshape.
1080     // init reshape and set cost for each input_layout and output_layout.
1081     auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
1082     reshape_info->set_pre_operator_name(pre_operator_info->name());
1083     reshape_info->set_pre_operator_index(out_index);
1084     if (!next_ops_index.empty()) {
1085       for (auto &op_index : next_ops_index) {
1086         // In recursive-programming (rec) mode, enumerate the strategies of the reshape's succeeding operator.
1087         if (ParallelContext::GetInstance()->strategy_search_mode() == kRecursiveProgramming) {
1088           ConstructCNodeCostGraphEdges(op_index.first->cnode(), all_nodes);
1089         }
1090         auto op_cost = op_index.first->strategy_cost();
1091         (void)next_costs_index.emplace_back(std::make_pair(op_cost, op_index.second));
1092       }
1093       reshape_info->set_next_operator_name(next_ops_index[0].first->name());
1094       reshape_info->set_next_operator_index(next_ops_index[0].second);
1095     }
1096     if (ParallelContext::GetInstance()->strategy_search_mode() != kRecursiveProgramming) {
1097       if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_costs_index, out_index, is_prev_param,
1098                                               is_next_reshape) != SUCCESS) {
1099         MS_LOG(EXCEPTION) << "Reshape generate strategy costs failed";
1100       }
1101     }
1102   }
1103 }
1104 
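// Detach the OperatorInfo user data from the CNodes of the operators marked as ignore candidates, so that
// they are no longer treated as part of the cost graph.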
1105 Status IgnoreOperatorsInCostGraph() {
1106   for (auto op = ignore_candidate_.cbegin(); op != ignore_candidate_.cend(); ++op) {
1107     auto cnodes = (*op)->cnodes();
1108     for (auto &cnode : cnodes) {
1109       MS_EXCEPTION_IF_NULL(cnode);
1110       cnode->set_user_data<OperatorInfo>(nullptr);
1111     }
1112   }
1113   return SUCCESS;
1114 }
1115 
1116 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1117   // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
1118   // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
1119   //      create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
1120   //      for each OperatorInfo;
1121   // Step 1.1: Deal with 'Reshape':
1122   //      For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
1123   //      layout as its output layout.
1124   // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
1125   //      create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
1126   //      for each edge, based on the strategies of two OperatorInfos;
1127   // Step 3: Augment the costgraph:
1128   //      taking care of the case where a single Parameter is used by multiple operators. Create a TmpIdentity
1129   //      operator for this Parameter, and add an edge for the use of this Parameter by each
1130   //      subsequent operator;
1131   // Step 3.1: Calculate memory usage:
1132   //      note that the memory usage calculation differs between the training phase and the inference phase.
1133   // Step 4: Run the strategy searching algorithm:
1134   //      If 'sharding_propagation' is configured to be true, then the configured sharding strategies are
1135   //      propagated to the non-configured operators, with the goal of minimizing the redistribution cost.
1136   //      Otherwise, the DP algorithm is used to search strategies for the costgraph. Note that there may be several
1137   //      connected components in the costgraph, and the DP algorithm runs on each of them.
1138   //
1139   // OUTPUT: the determined strategy for each operator.
1140 
1141   InitCostGraph();
1142   bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1143                 (ParallelContext::GetInstance()->sharding_propagation());
1144   // Step 1
1145   if (CostModelContext::GetInstance()->is_multi_subgraphs() || use_sp) {
1146     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1147       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1148                    << entire_costgraph->GetOperators().size() << " operators.";
1149     } else {
1150       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1151     }
1152   } else {
1153     if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1154       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1155                    << entire_costgraph->GetOperators().size() << " operators.";
1156     } else {
1157       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1158     }
1159   }
1160   // Step 1.1
1161   ReshapeCostCompute(all_nodes);
1162   // Step 2
1163   ConstructCostGraphEdges(all_nodes);
1164   MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
1165                << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
1166 
1167   // Step 3: Augment the costgraph.
1168   AugmentCostGraph(all_nodes);
1169   auto num_ops = entire_costgraph->GetOperators().size();
1170   SetOpsNumToExecutor(num_ops);
1171   auto num_edges = entire_costgraph->GetNumEdges();
1172   MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
1173                << " edges.";
1174 
1175   // Step 3.1: Calculate the memory usage
1176   if (!use_sp && entire_costgraph->CalculateMemoryCost() != SUCCESS) {
1177     MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
1178   }
1179 
1180   // Step 4: run the strategy searching algorithm
1181   if (use_sp) {
1182     entire_costgraph->StrategyPropagate(configured_stra_ops_);
1183   } else if (GetStrategy(entire_costgraph) != SUCCESS) {
1184     MS_LOG(ERROR) << "Strategy search for cost-graph fails";
1185     return FAILED;
1186   }
1187   MS_LOG(INFO) << "Searching strategy succeeded.";
1188 
1189   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1190     MS_LOG(INFO) << "Init selected strategy succeeded.";
1191   } else {
1192     MS_LOG(EXCEPTION) << "Init selected strategy failed.";
1193   }
1194 
1195   // print the selected strategy
1196   for (auto &op : entire_costgraph->GetOperators()) {
1197     StrategyPtr s_strategy = op->selected_strategy();
1198     if (s_strategy != nullptr) {
1199       MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1200     }
1201   }
1202   // Remove some operatorInfo from the CNODEs
1203   (void)IgnoreOperatorsInCostGraph();
1204 
1205   ops_in_a_loop_.clear();
1206   configured_stra_ops_.clear();
1207   ignore_candidate_.clear();
1208 
1209   return SUCCESS;
1210 }
1211 
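// Replace, in each row of the input-tensor-name list, the occurrence of the unique id 'it->first' with
// 'it->second'; used to collapse the TupleGetItem indirections gathered during parsing.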
1212 std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
1213                                                           std::vector<std::vector<std::string>> input_tensor_names) {
1214   for (size_t j = 0; j < input_tensor_names.size(); j++) {
1215     for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
1216       if (it->first == input_tensor_names[j][k]) {
1217         input_tensor_names[j][k] = it->second;
1218         break;
1219       }
1220     }
1221   }
1222   return input_tensor_names;
1223 }
1224 
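// For a TupleGetItem or Depend node, walk backwards to the CNode that actually produces the value,
// skipping nested TupleGetItem/Depend chains and stepping into func-graph outputs where necessary.
// Returns nullptr if no such primitive CNode is found.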
1225 CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
1226   auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1227   if (prim->name() == prim::kPrimTupleGetItem->name() || prim->name() == DEPEND) {
1228     auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
1229     if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1230       return nullptr;
1231     }
1232     if (IsValueNode<FuncGraph>(prev_cnode->input(0))) {
1233       size_t out_index = 0;
1234       out_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(INDEX_TWO))));
1235       auto graph = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
1236       auto output = graph->output();
1237       MS_EXCEPTION_IF_NULL(output);
1238       while (IsPrimitiveCNode(output, prim::kPrimDepend)) {
1239         auto output_cnode = output->cast<CNodePtr>();
1240         MS_EXCEPTION_IF_NULL(output_cnode);
1241         output = output_cnode->input(1);
1242       }
1243       while (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
1244         auto make_tuple_cnode = output->cast<CNodePtr>();
1245         output = make_tuple_cnode->input(out_index + 1);
1246       }
1247       prev_cnode = output->cast<CNodePtr>();
1248     }
1249 
1250     auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1251     while (prev_prim->name() == prim::kPrimTupleGetItem->name() || prev_prim->name() == DEPEND) {
1252       prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
1253       if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1254         return nullptr;
1255       }
1256       prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1257     }
1258     return prev_cnode;
1259   }
1260   return nullptr;
1261 }
1262 
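// Locate the operator named 'name' in the cost graph and rewrite every occurrence of 'uniqueid' in the
// input-tensor-name list to that operator's own tensor name (the first entry of its row).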
1263 void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
1264   size_t iter_ops = 0;
1265   for (const auto &op : entire_costgraph->GetOperators()) {
1266     if (op->name() == name) {
1267       break;
1268     }
1269     iter_ops = iter_ops + 1;
1270   }
1271 
1272   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1273   for (size_t i = 0; i < input_tensor_names.size(); i++) {
1274     for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
1275       if (input_tensor_names[i][j] == uniqueid) {
1276         input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
1277       }
1278     }
1279   }
1280 
1281   entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
1282 }
1283 
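// Return the index of the row whose first entry equals 'unique_id', or SIZE_MAX when no operator matches.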
1284 size_t FindOperatorIndexById(const std::string &unique_id,
1285                              const std::vector<std::vector<std::string>> &input_tensor_names) {
1286   for (size_t i = 0; i < input_tensor_names.size(); i++) {
1287     if (input_tensor_names[i][0] == unique_id) {
1288       return i;
1289     }
1290   }
1291   return SIZE_MAX;
1292 }
1293 
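// For every parameter, map the unique ids of its user operators to their indices in the cost graph,
// producing the groups of operators that share the same input tensor.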
1294 std::vector<std::vector<size_t>> GetIndexOfOpsSharingInputTensor(
1295   const std::vector<std::vector<std::string>> &param_users_uniqueid_list,
1296   const std::vector<std::vector<std::string>> &input_tensor_names) {
1297   std::vector<std::vector<size_t>> param_users_ops_index;
1298   for (const auto &users_uniqueid : param_users_uniqueid_list) {
1299     std::vector<size_t> users_index;
1300     for (const auto &user_uniqueid : users_uniqueid) {
1301       size_t user_index = FindOperatorIndexById(user_uniqueid, input_tensor_names);
1302       if (user_index != SIZE_MAX) {
1303         users_index.push_back(user_index);
1304       }
1305     }
1306     param_users_ops_index.push_back(users_index);
1307   }
1308   return param_users_ops_index;
1309 }
1310 
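// Calculate the micro batch size used in the pipeline scenario by dividing the GetNext batch dimension by
// the number of data users behind the virtual dataset. For example (hypothetical numbers): a batch
// dimension of 32 with 4 data users gives a micro batch size of 8.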
1311 void CalculateMicroBatchSize(const std::shared_ptr<Graph> &graph, const FuncGraphPtr &root) {
1312   // The first dimension of an operator is its batch dimension.
1313   // However, the size of that dimension is not necessarily the batch_size assigned by the user.
1314   // This function helps to calculate the micro batch size in the pipeline scenario.
1315 
1316   auto manager = root->manager();
1317   auto ops = entire_costgraph->GetOperators();
1318   AnfNodePtr virtual_dataset_;
1319   for (auto &fg : manager->func_graphs()) {
1320     for (auto &node : fg->nodes()) {
1321       if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
1322         virtual_dataset_ = node;
1323         break;
1324       }
1325     }
1326   }
1327   if (!virtual_dataset_) {
1328     // Normally, for auto-parallel, a virtual dataset is required to control the inputs' parallel strategies.
1329     // However, in some test cases or networks, there is no input data.
1330     // This branch handles those cases by using a micro batch size of 1.
1331     graph->micro_batch_size = 1;
1332     return;
1333   }
1334   auto node_user_map = manager->node_users();
1335   auto node_users = node_user_map[virtual_dataset_];
1336   int64_t data_user_size = 0;
1337   int64_t total_batch_size = 0;
1338   for (auto &node_user : node_users) {
1339     if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
1340       auto data_users = manager->node_users()[node_user.first];
1341       auto node_first = data_users.front().first;
1342       if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
1343         data_users.clear();
1344         data_users = node_user_map[node_first];
1345       }
1346       data_user_size = int64_t(data_users.size());
1347     }
1348   }
1349 
1350   for (auto op : ops) {
1351     if (op->type() == GET_NEXT) {
1352       for (auto shape : op->outputs_shape()) {
1353         if (!shape.empty()) {
1354           total_batch_size = shape[0];
1355           break;
1356         }
1357       }
1358       break;
1359     }
1360   }
1361   if (data_user_size != 0) {
1362     graph->micro_batch_size = total_batch_size / data_user_size;
1363     MS_LOG(INFO) << "In the pipeline scenario, the micro_batch_size of each stage is " << graph->micro_batch_size;
1364   } else {
1365     MS_LOG(EXCEPTION) << "Data user size equals to 0, which could not be divided by the total batch size";
1366   }
1367 }
1368 
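// Create the cost-graph nodes (one OperatorInfo per primitive); the TC variant of node construction is
// used when the cost model is configured for multiple subgraphs.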
1369 void CreateNodesForCostGraph(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1370   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
1371     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1372       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1373                    << entire_costgraph->GetOperators().size() << " operators.";
1374     } else {
1375       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1376     }
1377   } else {
1378     if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1379       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1380                    << entire_costgraph->GetOperators().size() << " operators.";
1381     } else {
1382       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1383     }
1384   }
1385 }
1386 
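// Rebuild the cost graph from scratch; Reshape cost computation is skipped under the temporary
// dynamic-shape fix.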
1387 void ReInitCostGraph(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, bool dyn_shape_tmp_fix) {
1388   InitCostGraph();
1389   CreateNodesForCostGraph(all_nodes, root);
1390   if (!dyn_shape_tmp_fix) {
1391     ReshapeCostCompute(all_nodes);
1392   }
1393 }
1394 
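// Write each operator's selected strategy back to its CNode as the IN_STRATEGY primal attribute, skipping
// Cast and Reshape operators. For example (hypothetical shapes), a MatMul sharded as ((2, 1), (1, 4))
// would carry in_strategy = ((2, 1), (1, 4)).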
1395 void WriteStrategiesBackToAnfGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
1396   for (auto &op : ops) {
1397     auto op_type = op->type();
1398     if (op_type == CAST || op_type == RESHAPE) {
1399       continue;
1400     }
1401     auto op_strategy = op->selected_strategy()->GetInputDim();
1402     if (!op_strategy.empty()) {
1403       std::vector<ValuePtr> strategies;
1404       (void)std::transform(op_strategy.begin(), op_strategy.end(), std::back_inserter(strategies),
1405                            [](const Dimensions &dim) { return MakeValue(dim); });
1406       ValueTuplePtr var = std::make_shared<ValueTuple>(strategies);
1407       op->cnode()->AddPrimalAttr(parallel::IN_STRATEGY, var);
1408     }
1409   }
1410 }
1411 
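// Temporary dynamic-shape workaround for BatchMatMul in SAPP: when the first input's shape_c is known
// (-1 means unknown) but the second input's is not, copy it to the second input, to the node's output,
// and to the matching argument of the first outgoing node.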
1412 void TmpInferBatchMatMul(const std::shared_ptr<Graph> &graph, Graph::NodeType *node) {
1413   if (node->apply.arguments[0].tensor_shape.shape_c != -1 && node->apply.arguments[1].tensor_shape.shape_c == -1) {
1414     auto infer_shape = node->apply.arguments[0].tensor_shape.shape_c;
1415     node->apply.arguments[1].tensor_shape.shape_c = infer_shape;
1416 
1417     if (node->node_out.size() == 0) {
1418       MS_LOG(EXCEPTION) << "The current BatchMatMul (" << node->name << ") does not have an outgoing node.";
1419     }
1420     auto &outgoing_node = graph->nodes[node->node_out[0]];
1421     if (outgoing_node.apply.arguments[0].tensor_shape.shape_c == node->tensor_parm.tensor_shape.shape_c) {
1422       outgoing_node.apply.arguments[0].tensor_shape.shape_c = infer_shape;
1423     }
1424 
1425     node->tensor_parm.tensor_shape.shape_c = infer_shape;
1426   }
1427 }
1428 
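// Walk the rec-graph nodes in reverse order and apply the BatchMatMul shape-inference workaround above.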
1429 void TmpInferForDynamicShapeInSAPP(const std::shared_ptr<Graph> &graph) {
1430   for (size_t index = graph->nodes.size(); index > 0; index--) {
1431     auto &node = graph->nodes[index - 1];
1432     if (node.apply.op_type == OperatorType::kRecBatchMatMul) {
1433       TmpInferBatchMatMul(graph, &node);
1434     }
1435   }
1436 }
1437 
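// Return true when any operator's primitive carries a user-configured IN_STRATEGY attribute.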
1438 bool HasUserConfiguredStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
1439   for (auto op : ops) {
1440     auto prim_anf_node = GetValueNode<PrimitivePtr>(op->cnode()->input(0));
1441     bool has_user_configured_strategy = prim_anf_node->HasAttr(parallel::IN_STRATEGY);
1442     if (has_user_configured_strategy) {
1443       return true;
1444     }
1445   }
1446   return false;
1447 }
1448 
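// Strategy search for the recursive-programming (SAPP) mode: build the cost graph, parse it into the
// rec graph, partition it across devices, then generate and apply a strategy for each operator.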
1449 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, size_t rank_id,
1450                                  const size_t device_num) {
1451   bool dyn_shape_tmp_fix = false;
1452   if (device_num > 0) {
1453     dyn_shape_tmp_fix = true;
1454   }
1455 
1456   ReInitCostGraph(all_nodes, root, dyn_shape_tmp_fix);
1457   auto ops = entire_costgraph->GetOperators();
1458   if (dyn_shape_tmp_fix && HasUserConfiguredStrategy(ops)) {
1459     MS_LOG(WARNING) << "The split strategies will now be generated automatically through SAPP, which will "
1460                        "overwrite any strategies manually configured by the user.";
1461   }
1462 
1463   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1464   // Needed by rec_parser 2
1465   auto param_users_uniqueid_list = entire_costgraph->get_param_users_uniqueid_list();
1466   auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
1467   for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
1468     input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
1469   }
1470   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
1471 
1472   std::vector<std::vector<size_t>> param_users_ops_index =
1473     GetIndexOfOpsSharingInputTensor(param_users_uniqueid_list, input_tensor_names);
1474   std::shared_ptr<std::vector<std::vector<size_t>>> eli_list = std::make_shared<std::vector<std::vector<size_t>>>();
1475   std::shared_ptr<std::vector<size_t>> index_list = std::make_shared<std::vector<size_t>>();
1476   graph = EliminateGraph(graph, eli_list, index_list, dyn_shape_tmp_fix);
1477   graph->dyn_shape_tmp_fix = dyn_shape_tmp_fix;
1478 
1479   if (graph->dyn_shape_tmp_fix) {
1480     TmpInferForDynamicShapeInSAPP(graph);
1481   }
1482 
1483   if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1484     CalculateMicroBatchSize(graph, root);
1485   }
1486 
1487   size_t num_device = g_device_manager->DeviceNum();
1488   const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
1489   // Specify whether the process is training or inference. For training, if optimizer parallelism is enabled, at
1490   // least one cut on the DP dimension is required.
1491   bool isTraining = IsTraining(root->manager());
1492   if (PartitionForAllDevices(num_device, device_memory, graph, isTraining, root) == SUCCESS) {
1493     MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
1494   } else {
1495     MS_LOG(ERROR) << "PartitionForAllDevices failed.";
1496     return FAILED;
1497   }
1498 
1499   // Needed when changing stage number
1500   if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
1501     if (!graph->dyn_shape_tmp_fix) {
1502       if (ParallelInit() != SUCCESS) {
1503         MS_LOG(EXCEPTION) << "Parallel init failed after Rec search";
1504       }
1505     } else {
1506       if (ParallelInit(rank_id, device_num) != SUCCESS) {
1507         MS_LOG(EXCEPTION) << "Parallel init failed";
1508       }
1509     }
1510     if (parallel::ParallelContext::GetInstance()->auto_pipeline()) {
1511       ReInitCostGraph(all_nodes, root, graph->dyn_shape_tmp_fix);
1512       ops = entire_costgraph->GetOperators();
1513     }
1514   }
1515 
1516   GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, isTraining, param_users_ops_index, root);
1517 
1518   // print the selected strategy
1519   for (auto &op : entire_costgraph->GetOperators()) {
1520     StrategyPtr s_strategy = op->selected_strategy();
1521     if (s_strategy != nullptr) {
1522       MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1523     }
1524   }
1525 
1526   if (graph->dyn_shape_tmp_fix) {
1527     (void)WriteStrategiesBackToAnfGraph(ops);
1528     (void)IgnoreOperatorsInCostGraph();
1529     ops_in_a_loop_.clear();
1530     configured_stra_ops_.clear();
1531     ignore_candidate_.clear();
1532     return SUCCESS;
1533   }
1534 
1535   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1536     MS_LOG(INFO) << "Init selected strategy succeeded.";
1537   } else {
1538     MS_LOG(ERROR) << "Init selected strategy failed.";
1539     return FAILED;
1540   }
1541 
1542   (void)IgnoreOperatorsInCostGraph();
1543   ops_in_a_loop_.clear();
1544   configured_stra_ops_.clear();
1545   ignore_candidate_.clear();
1546 
1547   return SUCCESS;
1548 }
1549 
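// Rebuild the cost graph and assign each operator the strategy loaded from the strategy-checkpoint file;
// returns FAILED if any operator has no saved strategy.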
1550 Status LoadStrategyFromFile(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
1551   InitCostGraph();
1552   bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == kShardingPropagation) ||
1553                 (ParallelContext::GetInstance()->sharding_propagation());
1554   if (CostModelContext::GetInstance()->is_multi_subgraphs() || use_sp) {
1555     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1556       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1557                    << entire_costgraph->GetOperators().size() << " operators.";
1558     } else {
1559       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1560     }
1561   } else {
1562     if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1563       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1564                    << entire_costgraph->GetOperators().size() << " operators.";
1565     } else {
1566       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1567     }
1568   }
1569 
1570   ReshapeCostCompute(all_nodes);
1571   // load strategy map from json
1572   StrategyMap stra_map;
1573   StrategyPtr strategy = nullptr;
1574   if ((StrategyCheckpoint::GetInstance().LoadAutoOpStrategy(&stra_map) != SUCCESS)) {
1575     return FAILED;
1576   }
1577   for (auto &op : entire_costgraph->GetOperators()) {
1578     std::string strategy_key_name = op->cnodes()[0]->fullname_with_scope();
1579     bool load_strategy_from_json = stra_map.find(strategy_key_name) != stra_map.end();
1580     if (!load_strategy_from_json) {
1581       MS_LOG(INFO) << "not found strategy for " << strategy_key_name;
1582       return FAILED;
1583     }
1584     strategy = stra_map[strategy_key_name];
1585     op->SetSelectedStrategy(strategy, 0);
1586   }
1587   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1588     MS_LOG(INFO) << "Init selected strategy succeeded.";
1589   } else {
1590     MS_LOG(INFO) << "Init selected strategy failed.";
1591     return FAILED;
1592   }
1593 
1594   // print the selected strategy
1595   for (auto &op : entire_costgraph->GetOperators()) {
1596     StrategyPtr s_strategy = op->selected_strategy();
1597     if (s_strategy != nullptr) {
1598       MS_LOG(INFO) << op->name() << ": The strategy is: " << s_strategy->ToString();
1599     }
1600   }
1601   (void)IgnoreOperatorsInCostGraph();
1602   ops_in_a_loop_.clear();
1603   configured_stra_ops_.clear();
1604   ignore_candidate_.clear();
1605 
1606   MS_LOG(INFO) << "End load strategies from file";
1607   return SUCCESS;
1608 }
1609 
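// Persist the selected strategy of every operator, keyed by the full scope name of its first CNode,
// through the strategy checkpoint.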
1610 void SaveStrategyToFile() {
1611   StrategyMap stra_map;
1612   TensorInfoMap tensor_info_map;
1613   ManualShapeMap manual_shape_map;
1614 
1615   for (auto &op : entire_costgraph->GetOperators()) {
1616     StrategyPtr s_strategy = op->selected_strategy();
1617     std::string strategy_key_name = op->cnodes()[0]->fullname_with_scope();
1618     stra_map[strategy_key_name] = s_strategy;
1619   }
1620   if (StrategyCheckpoint::GetInstance().SaveAutoOpStrategy(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
1621     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
1622   }
1623   MS_LOG(INFO) << "Success save strategies to file.";
1624 }
1625 }  // namespace parallel
1626 }  // namespace mindspore
1627