1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/step_auto_parallel.h"
18 
19 #include <cinttypes>
20 #include <ctime>
21 #include <algorithm>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 #include <unordered_set>
30 
31 #include "base/core_ops.h"
32 #include "frontend/optimizer/opt.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
35 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
36 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
37 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
38 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
39 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
40 #include "frontend/parallel/context.h"
41 #include "frontend/parallel/graph_util/node_info.h"
42 #include "frontend/parallel/graph_util/graph_info.h"
43 #include "frontend/parallel/ops_info/reshape_info.h"
44 #include "frontend/parallel/ops_info/tmp_identity_info.h"
45 #include "frontend/parallel/step_parallel.h"
46 #include "frontend/parallel/parameter_manager.h"
47 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
48 #include "ir/anf.h"
49 #include "ir/param_info.h"
50 #include "ir/tensor.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #endif
54 
55 namespace mindspore {
56 namespace parallel {
57 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
58 #if ((defined ENABLE_CPU) && (!defined _WIN32))
59   if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
60     return false;
61   }
62 #endif
63   MS_EXCEPTION_IF_NULL(root);
64   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
65   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
66   // assume no change to graph
67   bool changes = false;
68   // control whether the model_parallel mode is used
69   if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
70       root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
71     return changes;
72   }
73 
74   // check whether strategy_search_mode is valid
75   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
76   if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
77     // Set the searching mode to dynamic programming as the default.
78     strategy_search_mode = DYNAMIC_PROGRAMMING;
79     MS_LOG(INFO) << "Strategy searching mode not indicated, using DP searching mode as default";
80   }
81 
82   struct timeval start_time {
83     0
84   }, end_time{0};
85   (void)gettimeofday(&start_time, nullptr);
86 #ifdef ENABLE_DUMP_IR
87   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
88     draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
89   }
90 #endif
91   MS_LOG(INFO) << "Now entering step auto parallel";
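  // TOTAL_OPS is reset here; it is used later (in AugmentCostGraph) to give each created
  // TmpIdentityInfo operator a unique name.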
92   TOTAL_OPS = 0;
93   AnfNodePtr ret = root->get_return();
94   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
95   if (ParallelInit() != SUCCESS) {
96     MS_LOG(EXCEPTION) << "Parallel init failed";
97   }
98   // mark the forward cnodes; parallel only cares about these nodes
99   MarkForwardCNode(root);
100   if (IsInsertVirtualOutput(root)) {
101     InsertVirtualOutput(root, all_nodes);
102     AnfNodePtr ret_after = root->get_return();
103     MS_EXCEPTION_IF_NULL(ret_after);
104     all_nodes = DeepScopedGraphSearch(ret_after);
105   }
106   if (FindCommunicationOp(all_nodes)) {
107     MS_LOG(EXCEPTION) << "The graph contains communication ops";
108   }
109 
110   // search parallelization strategy
111   if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
112     if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
113       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
114     }
115   } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
116     if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
117       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
118     }
119   } else {
120     MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
121   }
122 
123   (void)gettimeofday(&end_time, nullptr);
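  // Elapsed time: the whole-second difference scaled by kUSecondInSecond plus the residual
  // microsecond difference; the result is reported in microseconds in the log below.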
124   uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
125   time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
126   MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
127 
128   root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
129   return changes;
130 }
131 
132 bool IsElementWiseOperator(const std::string &op_name) {
133   // clang-format off
134   static const std::set<std::string> elementwise_op = {ACTIVATION, GELU,         TANH,
135                                                        SOFTMAX,    LOG_SOFTMAX,  RELU,
136                                                        SQRT,       CAST,         POW,
137                                                        EXP,        LOG,          COS,
138                                                        ACOS,       LOGICALNOT,   NEG,
139                                                        SQUARE,     SIGMOID,      ABS,
140                                                        ACOSH,      ASIN,         ASINH,
141                                                        ATAN,       ATANH,        CEIL,
142                                                        COSH,       EXPM1,        LOG1P,
143                                                        SIN,        SINH,         TAN,
144                                                        RSQRT,      RECIPROCAL,   INV,
145                                                        ROUND,      FLOOR,        SIGN,
146                                                        ERF,        ERFC,         ZEROSLIKE,
147                                                        ONESLIKE,   BESSELI0E,    MOD,
148                                                        ASSIGN,     ASSIGN_ADD,   ATAN2,
149                                                        DIVNONAN,   LOGICALAND,   ELU,
150                                                        LOGICALOR,  RELU6,        SOFTPLUS,
151                                                        SOFTSIGN,   LESS,         LESSEQUAL,
152                                                        BESSELI1E,  GREATEREQUAL, APPROXIMATEEQUAL,
153                                                        REPEAT_ELEMENTS};
154   // clang-format on
155   auto iter = elementwise_op.find(op_name);
156   return (iter != elementwise_op.end());
157 }
158 
159 bool IsSplittableOperator(const std::string &op_name) {
160   // clang-format off
161   static const std::set<std::string> splittable_op =
162     {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
163      FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
164      REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
165      MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
166      LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
167      STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
168      SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
169      EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
170      EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
171      BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
172      SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
173      UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
174      UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
175      MATMUL_DDS, DSD_MATMUL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, UNIFORMREAL};
176   // clang-format on
177 
178   auto iter = splittable_op.find(op_name);
179   return (iter != splittable_op.end());
180 }
181 
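// Decides whether auto-parallel should handle this CNode: the node must be a parallel-care node
// whose primitive is splittable. A parallel-care node with an unsplittable primitive (other than
// MakeTuple/MakeList) raises an exception, and Cast nodes created by the optimizer are skipped.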
182 bool IsAutoParallelCareNode(const CNodePtr &cnode) {
183   MS_EXCEPTION_IF_NULL(cnode);
184   ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
185   if (prim_node == nullptr) {
186     return false;
187   }
188   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
189   if (prim == nullptr) {
190     return false;
191   }
192   bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
193   if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
194     MS_LOG(EXCEPTION) << "Should implement OperatorInfo for: " << prim->name();
195   } else if (prim->name() == CAST) {
196     if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
197       // Ignore CASTs inserted by the optimizer
198       return false;
199     }
200     return true;
201   }
202   return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
203 }
204 
205 // Records the operators appearing in a for-loop.
206 // Currently, we assume that the operators in different for-loops are identical, and that their traversal
207 // orderings are also identical.
208 // Therefore, we create OperatorInfo objects for the operators in one loop (say, loop-3), and reuse them in
209 // the remaining loops (loop-2, loop-1 and loop-0).
210 std::set<std::string> ops_in_a_loop_;
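// Only the operator names are stored; membership is checked in IsOperatorsInTwoSeparateLoops below.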
211 // Returns true if the two operators are in different loops.
212 // If at least one of the two operators is not in a loop, returns false.
213 // If the two operators are in the same loop, returns false.
214 bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
215   auto a_op_info = a_cnode->user_data<OperatorInfo>();
216   MS_EXCEPTION_IF_NULL(a_op_info);
217   auto b_op_info = b_cnode->user_data<OperatorInfo>();
218   MS_EXCEPTION_IF_NULL(b_op_info);
219   if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
220       (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
221     return false;
222   }
223   size_t a_loop_index = 0, b_loop_index = 0;
224   const auto &a_fullname = a_cnode->fullname_with_scope();
225   if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
226     MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
227   }
228   const auto &b_fullname = b_cnode->fullname_with_scope();
229   if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
230     MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
231   }
232   if (a_loop_index == b_loop_index) {
233     return false;
234   }
235   return true;
236 }
237 
238 // 'configured_stra_ops_' includes all operators that have configured sharding strategies.
239 std::map<OperatorInfoPtr, StrategyPtr> configured_stra_ops_;
240 void InitCostGraph() {
241   if (entire_costgraph == nullptr) {
242     entire_costgraph = std::make_shared<CostGraph>();
243   }
244   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
245   CostModelContext::GetInstance()->PrintCostModel();
246   entire_costgraph->Init();
247   configured_stra_ops_.clear();
248 }
249 
250 void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
251                            std::unordered_map<std::string, ValuePtr> attrs, bool, StrategyMap *stra_map,
252                            const std::string &strategy_key_name) {
253   // In this case, the configured strategy should be extracted to help set the cost
254   StrategyPtr strategyPtr;
255   if (StrategyFound(attrs)) {
256     strategyPtr = parallel::ExtractStrategy(attrs[STRATEGY]);
257   } else {
258     strategyPtr = (*stra_map)[strategy_key_name];
259   }
260   if (strategyPtr != nullptr) {
261     if (prim->name() == RESHAPE) {
262       MS_LOG(EXCEPTION) << "Setting strategy for Reshape has no effect!";
263     }
264     const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
265     // Set cost for this configured strategy
266     if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
267       MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
268     } else if (fully_use_devices) {
269       // If configured to fully use devices, then check the user-specified strategy
270       int64_t used_devices = operator_info->used_devices();
271       MS_EXCEPTION_IF_NULL(g_device_manager);
272       auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
273       // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
274       if (used_devices == 1) {
275         (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
276         return;
277       }
278       // 'used_devices == -1' means that 'used_devices_' is not set
279       if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
280         MS_LOG(EXCEPTION) << "In current configuration 'fully_use_devices' = True, "
281                           << "but the specified strategy uses device: " << used_devices
282                           << ", total devices: " << total_device_num
283                           << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
284                              "in package 'mindspore.parallel'.";
285       }
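      // As the message above suggests, a user hitting this exception can call
      // set_algo_parameters(fully_use_devices=False) from the 'mindspore.parallel' package
      // so that strategies not using all devices are accepted.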
286     }
287     (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
288   }
289 }
290 
291 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
292                                       StrategyMap *stra_map) {
293   MS_EXCEPTION_IF_NULL(prim);
294   MS_EXCEPTION_IF_NULL(cnode);
295   auto attrs = prim->attrs();
296   std::vector<Shapes> shape_list = ExtractShape(cnode);
297   if (shape_list.empty()) {
298     MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
299   }
300   // Create an OperatorInfo instance
301   OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
302   MS_EXCEPTION_IF_NULL(operator_info);
303   // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
304   std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
305   if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
306     MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
307     return nullptr;
308   }
309   // Set the data type for inputs and outputs of this OperatorInfo
310   auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
311   auto outputs_type = ExtractOutputTypeByNode(cnode);
312   std::vector<size_t> outputs_type_length;
313   outputs_type_length.reserve(outputs_type.size());
314   std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
315                  GetLengthOfDataType);
316   if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
317     MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
318     return nullptr;
319   }
320   if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
321     MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
322     return nullptr;
323   }
324   // When 'inputs' contains numerical values for some operators, these values should be extracted from
325   // the ANF graph
326   auto &inputs = cnode->inputs();
327   std::vector<ValuePtr> input_value;
328   for (size_t index = 1; index < inputs.size(); ++index) {
329     if (inputs[index]->isa<ValueNode>()) {
330       input_value.push_back(GetValueNode(inputs[index]));
331     } else {
332       input_value.emplace_back(nullptr);
333     }
334   }
335   operator_info->set_input_value(input_value);
336   operator_info->set_outputs_dtype(cnode->Type());
337   operator_info->set_cnode(cnode);
338   // key of strategy map
339   std::string strategy_key_name = "";
340   auto param_names = NodeParameterName(cnode, -1, 0);
341   if (!param_names.empty()) {
342     strategy_key_name = prim->name() + "_" + param_names[0].first;
343   }
344   bool load_strategy_from_ckpt =
345     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
346   // If no strategy has been configured for this operator, then candidate strategies are generated for
347   // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
348   // If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
349   if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
350     // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used in preparation
351     // for the BatchParallelInfo operator
352     operator_info->ComputeBatchSplitFlagList();
353     if (operator_info->GenerateStrategies(0) != SUCCESS) {
354       MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
355       return nullptr;
356     }
357     if (ParallelContext::GetInstance()->sharding_propagation() &&
358         (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
359       const auto &swc_vec = operator_info->GetStrategyCost();
360       if (swc_vec.empty()) {
361         MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
362       }
363       MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
364       (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
365     }
366     // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
367     auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
368     if (approximation) {
369       operator_info->ApproximateStrategies();
370       MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
371     }
372   } else {
373     SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
374   }
375   return operator_info;
376 }
377 
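// Returns true when the OperatorInfo assigned to a CNode does not correspond to its primitive:
// the OperatorInfo name matches none of VirtualDatasetInfo, BatchParallel, prim_name + "Info"
// (and, for GatherV2, prim_name + "PInfo").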
378 bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_name) {
379   bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
380                        (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
381                        (current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
382   if (prim_name == GATHERV2) {
383     is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
384   }
385   return is_find_wrong;
386 }
387 
388 // Using CNode's UniqueIds to construct nodes
389 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
390   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
391   // The map from CNode's UniqueId to its operatorInfo
392   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
393   // The operator_infos in a loop
394   std::vector<OperatorInfoPtr> operators_in_forloop;
395   // Key: i-th loop; Value: index of 'operators_in_forloop'
396   std::map<size_t, size_t> loop_to_ops;
397   // extract strategy from checkpoint for multi-train
398   StrategyMap stra_map;
399   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
400     if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
401       MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
402     }
403   }
404 
405   for (auto &node : all_nodes) {
406     // NOTE: we only care about splittable Primitive operators
407     auto cnode = node->cast<CNodePtr>();
408     bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
409     if (bool_result) {
410       continue;
411     }
412     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
413     if (!IsAutoParallelCareNode(cnode)) {
414       // Needed by rec_parser
415       if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
416         auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
417         if (prev_cnode != nullptr) {
418           entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
419         }
420       }
421       continue;
422     }
423     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
424     MS_EXCEPTION_IF_NULL(prim);
425 
426     auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
427     if (search_cnode == from_cnode_to_info.end()) {
428       size_t loop_index = 0;
429       bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
430       const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
431       if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
432         const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
433         if (IsFindWrong(current_op_ptr, prim->name())) {
434           MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
435                             << " does not match the Prim: " << prim->name()
436                             << ". The fullname_with_scope: " << cnode->fullname_with_scope();
437         }
438         loop_to_ops[loop_index]++;
439         cnode->set_user_data<OperatorInfo>(current_op_ptr);
440         MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
441                      << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
442                      << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
443                      << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
444         (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
445         continue;
446       }
447       bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
448       auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
449       if (operator_info == nullptr) {
450         return FAILED;
451       }
452       // Needed by rec_parser
453       operator_info->set_type(prim->name());
454       operator_info->set_last_node_flag(is_last_nodes);
455       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
456 
457       entire_costgraph->AddOperator(operator_info);
458       cnode->set_user_data<OperatorInfo>(operator_info);
459       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
460                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
461                    << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
462                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
463       (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
464       if (single_loop && is_in_loop) {
465         operators_in_forloop.push_back(operator_info);
466         ops_in_a_loop_.insert(operator_info->name());
467         loop_to_ops[loop_index]++;
468       }
469       // Needed by rec_parser
470       entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
471     } else {
472       // Two CNODEs' UniqueIds should not be equal
473       MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
474                         << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
475                         << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
476     }
477   }
478 
479   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
480   return SUCCESS;
481 }
482 
483 void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) {
484   if (current_op_ptr == nullptr) {
485     MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
486   } else {
487     if (IsFindWrong(current_op_ptr, prim->name())) {
488       MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
489                         << " does not match the Prim: " << prim->name();
490     }
491 
492     // Needed by rec_parser
493     ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
494 
495     cnode->set_user_data<OperatorInfo>(current_op_ptr);
496     MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
497                  << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
498                  << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
499                  << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
500   }
501 }
502 
503 // Using CNodes' UniqueIdThroughCopy values to construct nodes
504 Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
505   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
506   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
507   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
508   // The operator_infos in a loop
509   std::vector<OperatorInfoPtr> operators_in_forloop;
510   // Key: i-th loop; Value: index of 'operators_in_forloop'
511   std::map<size_t, size_t> loop_to_ops;
512   // extract strategy from checkpoint for multi-train
513   StrategyMap stra_map;
514   if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
515       StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
516     MS_LOG(WARNING) << "Load strategy checkpoint failed";
517     return FAILED;
518   }
519   for (auto &node : all_nodes) {
520     // NOTE: we only care about splittable Primitive operators
521     auto cnode = node->cast<CNodePtr>();
522     if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
523       continue;
524     }
525     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
526     if (!IsAutoParallelCareNode(cnode)) {
527       // Needed by rec_parser
528       if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
529         auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
530         if (prev_cnode != nullptr) {
531           entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
532         }
533       }
534       continue;
535     }
536     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
537 
538     // Find the operatorInfo if it exists
539     auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
540     if (search_cnode == from_cnode_to_info.end()) {
541       size_t loop_index = 0;
542       bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
543       const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
544       bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
545       if (is_op_created) {
546         const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
547         if (IsFindWrong(current_op_ptr, prim->name())) {
548           MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
549                             << " does not match the Prim: " << prim->name()
550                             << ". The fullname_with_scope: " << cnode->fullname_with_scope();
551         }
552         loop_to_ops[loop_index]++;
553         cnode->set_user_data<OperatorInfo>(current_op_ptr);
554         MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
555                      << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
556                      << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
557                      << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
558         (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), current_op_ptr));
559         continue;
560       }
561       // In this case, the corresponding OperatorInfo has not been created, so create a new one.
562       bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
563       auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
564       MS_EXCEPTION_IF_NULL(operator_info);
565 
566       // Needed by rec_parser
567       operator_info->set_type(prim->name());
568       operator_info->set_last_node_flag(is_last_nodes);
569       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
570 
571       entire_costgraph->AddOperator(operator_info);
572       cnode->set_user_data<OperatorInfo>(operator_info);
573       MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
574                    << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
575                    << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
576                    << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
577       (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
578       if (single_loop && is_in_loop) {
579         operators_in_forloop.push_back(operator_info);
580         ops_in_a_loop_.insert(operator_info->name());
581         loop_to_ops[loop_index]++;
582       }
583       // Needed by rec_parser
584       entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
585     } else {
586       SetOperatorToCNode(search_cnode->second, prim, cnode);
587     }
588   }
589 
590   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
591   return SUCCESS;
592 }
593 
594 void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
595                              const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim,
596                              const PrimitivePtr &prev_prim, size_t output_index, size_t input_index,
597                              size_t *edge_count) {
598   std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
599   // If the edge between these two operators has already been added, then the edge will not be added again.
600   if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) {
601     return;
602   }
603   EdgePtr edge_ptr;
604   MS_LOG(INFO) << "Creating edge: " << edge_name;
605   if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
606     MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
607                  << ", cnode_fullname: " << cnode->fullname_with_scope();
608     MS_LOG(INFO) << "The two operators are in two separate for-loops, thus the edge is skipped.";
609     return;
610   }
611   const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
612   bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
613                          (stra_follow && IsElementWiseOperator(prev_prim->name()));
614   if (follow_strategy) {
615     // Redistribution is not allowed on the edge.
616     // Elementwise operators have the same strategy as their previous operators.
617     edge_ptr =
618       std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true);
619   } else {
620     edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false);
621   }
622 
623   // Init costs for this edge
624   if (edge_ptr->InitEdgeCost() != SUCCESS) {
625     MS_LOG(EXCEPTION) << "Edge cost initialization failed";
626   }
627   node_op_info->AddPrevEdge(edge_ptr);
628   prev_op_info->AddSuccEdge(edge_ptr);
629   entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
630   if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
631       (configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
632     const auto next_op_stra = configured_stra_ops_[node_op_info];
633     const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithZeroComm(next_op_stra);
634     if (cast_stra == nullptr) {
635       MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
636     }
637     prev_op_info->ClearStrategyCost();
638     if (prev_op_info->SetCostUnderStrategy(cast_stra) != SUCCESS) {
639       MS_LOG(EXCEPTION) << "Failure: operator " << prev_op_info->name() << " SetCostUnderStrategy failed";
640     }
641     if (edge_ptr->InitEdgeCost() != SUCCESS) {
642       MS_LOG(EXCEPTION) << "Edge cost re-initialization failed.";
643     }
644     MS_LOG(INFO) << "Set strategy for: " << prev_op_info->name() << " under the strategy of: " << node_op_info->name();
645     (void)configured_stra_ops_.emplace(prev_op_info, cast_stra);
646   }
647   MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name();
648   (*edge_count)++;
649 }
650 
651 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
652   // Step 2
653   MS_LOG(INFO) << "Constructing edges for cost graph begins.";
654   for (auto &node : all_nodes) {
655     auto cnode = node->cast<CNodePtr>();
656     if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
657       continue;
658     }
659     auto &inputs = cnode->inputs();
660     ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
661     if (!IsAutoParallelCareNode(cnode)) {
662       continue;
663     }
664     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
665     size_t edge_count = 0;
666     auto node_op_info = cnode->user_data<OperatorInfo>();
667 
668     for (size_t i = 1; i < inputs.size(); ++i) {
669       auto prev_cnode = inputs[i]->cast<CNodePtr>();
670       bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
671       if (bool_result_prev_cnode) {
672         continue;
673       }
674       ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
675       PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
676       size_t output_index = 0;
677 
678       while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
679              (prev_prim->name() == DEPEND)) {
680         if (IsAutoParallelCareNode(prev_cnode)) {
681           auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
682           CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
683                                   &edge_count);
684           break;
685         } else if (prev_prim->name() == prim::kTupleGetItem) {
686           // In this case, 'prev_anf_node' is 'tuple_getitem'; the actual precursor node is the node before
687           // this 'tuple_getitem'
688           MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
689           output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
690           prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
691           bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
692           if (bool_result_tuple) {
693             break;
694           }
695           prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
696           prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
697           if (!IsAutoParallelCareNode(prev_cnode)) {
698             MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
699           }
700           MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
701                        << "creating an edge between the Operator before "
702                        << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
703         } else if (prev_prim->name() == DEPEND) {
704           // In this case, 'prev_anf_node' is 'depend'; the actual precursor node is the node before
705           // this 'depend'
706           MS_LOG(INFO) << "Jumping the 'depend' operator.";
707           prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
708           bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
709           if (bool_result_depend) {
710             break;
711           }
712           prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
713           prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
714           MS_LOG(INFO) << "Jumped the 'depend' operator, "
715                        << "creating an edge between the Operator before "
716                        << "'depend' and the Operator after 'depend'.";
717         }
718       }
719     }
720     MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
721   }
722   // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
723   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
724   if (approximation) {
725     entire_costgraph->CheckApproximateCostGraphEdges();
726   }
727 
728   MS_LOG(INFO) << "Constructing edges for cost graph ends.";
729 }
730 
731 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
732   // Step 3
733   for (auto &node : all_nodes) {
734     ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode);
735     auto parameter_name = parameter_users_info.first;
736     auto target_parameter = parameter_users_info.second.first;
737     auto target_set = parameter_users_info.second.second;
738     if (target_set.size() <= 1) {
739       continue;
740     }
741 
742     // Rule out the case in which a Parameter is used by an Operator, but the Operator appears in multiple CNodes
743     std::set<std::string> target_without_duplicate;
744     for (auto &target : target_set) {
745       auto target_cnode = target.first->cast<CNodePtr>();
746       auto input_index = target.second;
747       (void)target_without_duplicate.insert(std::to_string(input_index) +
748                                             target_cnode->user_data<OperatorInfo>()->name());
749     }
750     if (target_without_duplicate.size() <= 1) {
751       continue;
752     }
753 
754     // At this point, it is certain that this Parameter (RefKey) is used by multiple Operators.
755     OperatorInfoPtr tmp_identity_ptr;
756     bool new_identity = false;
757     std::string tmp_identity_name;
758     auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
759     if (returned_identity != nullptr) {
760       // In this case, the TmpIdentityInfo instance has already been created
761       new_identity = false;
762       tmp_identity_ptr = returned_identity;
763       tmp_identity_name = tmp_identity_ptr->name();
764     } else {
765       // In this case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
766       new_identity = true;
767       // 1) extract input shape from this Parameter
768       MS_EXCEPTION_IF_NULL(target_parameter);
769       AbstractBasePtr abstract = target_parameter->abstract();
770       if (abstract == nullptr) {
771         MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
772       }
773       auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
774       if (input_shape == nullptr) {
775         MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
776       }
777       Shape shape = input_shape->shape();
778       Shapes inputs_shape = {shape};
779       Shapes outputs_shape = {shape};
780       // 2) init the attr
781       std::unordered_map<std::string, ValuePtr> attr = {};
782 
783       // Create the TmpIdentity instance
784       tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
785       tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
786       TOTAL_OPS++;
787       tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
788       // Set the parameter and type lengths for inputs and outputs
789       std::vector<bool> is_parameter;
790       auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
791       MS_EXCEPTION_IF_NULL(casted_target_parameter);
792       is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
793       if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
794         MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
795       }
796       auto node_type = target_parameter->Type();
797       if (node_type->isa<mindspore::TensorType>()) {
798         auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
799         std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
800         if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
801           MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
802         }
803       } else {
804         MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
805       }
806 
807       // Generate strategies for this TmpIdentityInfo instance;
808       if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
809         MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
810       }
811     }
812     // A flag recording whether new edges have been created or not
813     bool add_identity_edge = false;
814 
815     // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
816     for (auto &target : target_set) {
817       auto target_cnode = target.first->cast<CNodePtr>();
818       auto input_index = target.second;
819       auto target_op_info = target_cnode->user_data<OperatorInfo>();
820 
821       std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
822       // If the edge between these two operators has already been added, then the edge will not be added again.
823       if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
824         continue;
825       }
826       std::shared_ptr<Edge> edge_ptr =
827         std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
828       // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
829       auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
830       if (approximation) {
831         target_op_info->ExactStrategiesAndRelatedEdges();
832       }
833 
834       if (edge_ptr->InitEdgeCost() != SUCCESS) {
835         MS_LOG(EXCEPTION) << "Edge cost initialization failed";
836       }
837       target_op_info->AddPrevEdge(edge_ptr);
838       tmp_identity_ptr->AddSuccEdge(edge_ptr);
839       entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
840       MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
841                    << target_op_info->name();
842       add_identity_edge = true;
843     }
844     if (new_identity && add_identity_edge) {
845       // Add the TmpIdentityInfo to CostGraph only if BOTH conditions are satisfied
846       entire_costgraph->AddOperator(tmp_identity_ptr);
847     }
848   }
849 }
850 
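// For each Reshape CNode, collect the strategy costs of its previous operator (a Parameter input is
// handled specially) and, when available, of its next operator, and let ReshapeInfo generate the
// Reshape's own strategy costs from those layouts.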
851 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
852   std::unordered_set<std::string> op_cache;
853   for (auto node : all_nodes) {
854     auto cnode = node->cast<CNodePtr>();
855     if (!FindReshape(cnode, &op_cache)) {
856       continue;
857     }
858     MS_ASSERT(cnode->inputs().size() == 3);
859     // get previous node's strategy_cost_
860     auto pre_node = cnode->input(1);
861     if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
862       pre_node = pre_node->cast<CNodePtr>()->input(1);
863     }
864     int64_t out_index = 0;
865     OperatorInfoPtr pre_operator_info;
866     std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
867     auto operator_info = cnode->user_data<OperatorInfo>();
868     if (pre_node->isa<Parameter>()) {
869       auto reshape_info1 = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
870       reshape_info1->SetCostForReshapeWithParameter();
871       pre_operator_info = reshape_info1;
872       pre_stra_costs = reshape_info1->strategy_cost();
873     } else {
874       if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index, 0)) {
875         MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
876       }
877       pre_stra_costs = pre_operator_info->strategy_cost();
878     }
879     // get next node's strategy_cost_
880     int64_t in_index = 0;
881     OperatorInfoPtr next_operator_info;
882     bool is_next_reshape = false;
883     std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
884     bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, &is_next_reshape, 0);
885     if (!find_next_node) {
886       MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
887     }
888     // set input_layout and output_layout for reshape.
889     // init reshape and set cost for each input_layout and output_layout.
890     auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
891     reshape_info->set_pre_operator_name(pre_operator_info->name());
892     reshape_info->set_pre_operator_index(out_index);
893     if (find_next_node) {
894       next_stra_costs = next_operator_info->strategy_cost();
895       reshape_info->set_next_operator_name(next_operator_info->name());
896       reshape_info->set_next_operator_index(in_index);
897     }
898     bool is_prev_param = pre_node->isa<Parameter>();
899     if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
900                                             is_next_reshape) != SUCCESS) {
901       MS_LOG(EXCEPTION) << "Reshape failed to generate strategy costs!";
902     }
903   }
904 }
905 
906 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
907   // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
908   // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
909   //      create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
910   //      for each OperatorInfo;
911   // Step 1.1: Deal with 'Reshape':
912   //      For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
913   //      layout as its output layout.
914   // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
915   //      create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
916   //      for each edge, based on the strategies of two OperatorInfos;
917   // Step 3: Augment the costgraph:
918   //      taking care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity
919   //      operator for this Parameter, and add an edge for the use of this Parameter by each
920   //      subsequent operator;
921   // Step 3.1: Calculate memory usage:
922   //      note the memory usage calculation is different in training phase and inference phase.
923   // Step 4: Run the strategy searching algorithm:
924   //      If 'sharding_propagation' is configured to be true, then the configured-sharding-strategies will propagate
925   //      to the non-configured operators, with the goal of minimizing redistribution cost.
926   //      Otherwise, DP algorithm is used to search strategy of the costgraph. Note that there may be several connected
927   //      components in the costgraph, and the DP algorithm runs on each of them.
928   //
929   // OUTPUT: the determined strategy for each operator.
930 
931   InitCostGraph();
932   // Step 1
933   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
934     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
935       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
936                    << entire_costgraph->GetOperators().size() << " operators.";
937     } else {
938       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
939     }
940   } else {
941     if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
942       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
943                    << entire_costgraph->GetOperators().size() << " operators.";
944     } else {
945       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
946     }
947   }
948   // Step 1.1
949   ReshapeCostCompute(all_nodes);
950   // Step 2
951   ConstructCostGraphEdges(all_nodes);
952   MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
953                << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
954 
955   // Step 3: Augment the costgraph.
956   AugmentCostGraph(all_nodes);
957   auto num_ops = entire_costgraph->GetOperators().size();
958   SetOpsNumToExecutor(num_ops);
959   auto num_edges = entire_costgraph->GetNumEdges();
960   MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
961                << " edges.";
962 
963   // Step 3.1: Calculate the memory usage
964   if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
965     MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
966   }
967 
968   // Step 4: run the strategy searching algorithm
969   if (ParallelContext::GetInstance()->sharding_propagation()) {
970     entire_costgraph->StrategyPropagate(configured_stra_ops_);
971     configured_stra_ops_.clear();
972   } else if (GetStrategy(entire_costgraph) != SUCCESS) {
973     MS_LOG(ERROR) << "Strategy search for cost-graph failed";
974     return FAILED;
975   }
976   MS_LOG(INFO) << "Searching strategy succeeded.";
977 
978   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
979     MS_LOG(INFO) << "Init selected strategy succeeded.";
980   } else {
981     MS_LOG(EXCEPTION) << "Init selected strategy failed.";
982   }
983 
984   // print the selected strategy
985   for (auto &op : entire_costgraph->GetOperators()) {
986     StrategyPtr s_strategy = op->selected_strategy();
987     MS_LOG(INFO) << op->name() << " : The strategy is:";
988     PrintStrategy(s_strategy);
989   }
990   ops_in_a_loop_.clear();
991 
992   return SUCCESS;
993 }
994 
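// Used by the rec_parser: every occurrence of it->first in the inputs-tensor-name list is replaced
// with it->second, resolving the TupleGetItem/Depend indirections recorded in the tuple_getitem list.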
995 std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
996                                                           std::vector<std::vector<std::string>> input_tensor_names) {
997   for (size_t j = 0; j < input_tensor_names.size(); j++) {
998     for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
999       if (it->first == input_tensor_names[j][k]) {
1000         input_tensor_names[j][k] = it->second;
1001         break;
1002       }
1003     }
1004   }
1005   return input_tensor_names;
1006 }
1007 
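// If 'cnode' is a TupleGetItem or Depend, walk up through consecutive TupleGetItem/Depend nodes and
// return the first Primitive CNode behind them; return nullptr otherwise. Needed by the rec_parser.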
1008 CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
1009   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1010   if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
1011     auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
1012     if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1013       return nullptr;
1014     }
1015     auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1016     while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
1017       prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
1018       if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1019         return nullptr;
1020       }
1021       prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1022     }
1023     return prev_cnode;
1024   }
1025   return nullptr;
1026 }
1027 
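// Replaces every occurrence of 'uniqueid' in the recorded inputs-tensor-name list with the first
// tensor name of the operator called 'name', so that a reused OperatorInfo is referenced consistently.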
1028 void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
1029   size_t iter_ops = 0;
1030   for (auto op : entire_costgraph->GetOperators()) {
1031     if (op->name() == name) {
1032       break;
1033     }
1034     iter_ops = iter_ops + 1;
1035   }
1036 
1037   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1038   for (size_t i = 0; i < input_tensor_names.size(); i++) {
1039     for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
1040       if (input_tensor_names[i][j] == uniqueid) {
1041         input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
1042       }
1043     }
1044   }
1045 
1046   entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
1047 }
1048 
1049 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1050   InitCostGraph();
1051   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
1052     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1053       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1054                    << entire_costgraph->GetOperators().size() << " operators.";
1055     } else {
1056       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1057     }
1058   } else {
1059     if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1060       MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1061                    << entire_costgraph->GetOperators().size() << " operators.";
1062     } else {
1063       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1064     }
1065   }
1066   ReshapeCostCompute(all_nodes);
1067 
1068   auto ops = entire_costgraph->GetOperators();
1069   std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1070   auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
1071   for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
1072     input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
1073   }
1074   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
1075 
1076   std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
1077   std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
1078   graph = EliminateGraph(graph, eli_list, index_list);
1079 
1080   size_t num_device = g_device_manager->DeviceNum();
1081   const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
1082   if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
1083     MS_LOG(INFO) << "Partition succeeded with " << num_device << " devices.";
1084   } else {
1085     MS_LOG(ERROR) << "PartitionForAllDevices failed.";
1086     return FAILED;
1087   }
1088 
1089   bool is_training = true;
1090   if (!root->has_flag(TRAINING)) {
1091     is_training = false;
1092   }
1093   GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);
1094 
1095   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1096     MS_LOG(INFO) << "Init selected strategy succeeded.";
1097   } else {
1098     MS_LOG(ERROR) << "Init selected strategy failed.";
1099     return FAILED;
1100   }
1101 
1102   // print the selected strategy
1103   for (auto &op : entire_costgraph->GetOperators()) {
1104     StrategyPtr s_strategy = op->selected_strategy();
1105     MS_LOG(INFO) << op->name() << " : The strategy is:";
1106     PrintStrategy(s_strategy);
1107   }
1108 
1109   return SUCCESS;
1110 }
1111 }  // namespace parallel
1112 }  // namespace mindspore
1113