1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/step_auto_parallel.h"
18
19 #include <cinttypes>
20 #include <ctime>
21 #include <algorithm>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 #include <unordered_set>
30
31 #include "base/core_ops.h"
32 #include "frontend/optimizer/opt.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
35 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
36 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
37 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
38 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
39 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
40 #include "frontend/parallel/context.h"
41 #include "frontend/parallel/graph_util/node_info.h"
42 #include "frontend/parallel/graph_util/graph_info.h"
43 #include "frontend/parallel/ops_info/reshape_info.h"
44 #include "frontend/parallel/ops_info/tmp_identity_info.h"
45 #include "frontend/parallel/step_parallel.h"
46 #include "frontend/parallel/parameter_manager.h"
47 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
48 #include "ir/anf.h"
49 #include "ir/param_info.h"
50 #include "ir/tensor.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #endif
54
55 namespace mindspore {
56 namespace parallel {
57 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
58 #if ((defined ENABLE_CPU) && (!defined _WIN32))
59 if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) {
60 return false;
61 }
62 #endif
63 MS_EXCEPTION_IF_NULL(root);
64 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
65 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
66 // assume no change to graph
67 bool changes = false;
68 // control whether to use the model_parallel mode
69 if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
70 root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
71 return changes;
72 }
73
74 // check whether strategy_search_mode is valid
75 std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
76 if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
77 // Setting searching mode: dynamic programming as default.
78 strategy_search_mode = DYNAMIC_PROGRAMMING;
79 MS_LOG(INFO) << "No strategy searching mode specified, using the DP searching mode as default";
80 }
81
82 struct timeval start_time {
83 0
84 }, end_time{0};
85 (void)gettimeofday(&start_time, nullptr);
86 #ifdef ENABLE_DUMP_IR
87 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
88 draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
89 }
90 #endif
91 MS_LOG(INFO) << "Now entering step auto parallel";
92 TOTAL_OPS = 0;
93 AnfNodePtr ret = root->get_return();
94 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
95 if (ParallelInit() != SUCCESS) {
96 MS_LOG(EXCEPTION) << "Parallel init failed";
97 }
98 // mark the forward cnodes, parallel only cares about these nodes
99 MarkForwardCNode(root);
100 if (IsInsertVirtualOutput(root)) {
101 InsertVirtualOutput(root, all_nodes);
102 AnfNodePtr ret_after = root->get_return();
103 MS_EXCEPTION_IF_NULL(ret_after);
104 all_nodes = DeepScopedGraphSearch(ret_after);
105 }
106 if (FindCommunicationOp(all_nodes)) {
107 MS_LOG(EXCEPTION) << "The graph contains a communication op";
108 }
109
110 // search parallelization strategy
111 if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
112 if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
113 MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
114 }
115 } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
116 if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
117 MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
118 }
119 } else {
120 MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
121 }
122
123 (void)gettimeofday(&end_time, nullptr);
124 uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
125 time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
126 MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
127
128 root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
129 return changes;
130 }
131
132 bool IsElementWiseOperator(const std::string &op_name) {
133 // clang-format off
134 static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
135 SOFTMAX, LOG_SOFTMAX, RELU,
136 SQRT, CAST, POW,
137 EXP, LOG, COS,
138 ACOS, LOGICALNOT, NEG,
139 SQUARE, SIGMOID, ABS,
140 ACOSH, ASIN, ASINH,
141 ATAN, ATANH, CEIL,
142 COSH, EXPM1, LOG1P,
143 SIN, SINH, TAN,
144 RSQRT, RECIPROCAL, INV,
145 ROUND, FLOOR, SIGN,
146 ERF, ERFC, ZEROSLIKE,
147 ONESLIKE, BESSELI0E, MOD,
148 ASSIGN, ASSIGN_ADD, ATAN2,
149 DIVNONAN, LOGICALAND, ELU,
150 LOGICALOR, RELU6, SOFTPLUS,
151 SOFTSIGN, LESS, LESSEQUAL,
152 BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL,
153 REPEAT_ELEMENTS};
154 // clang-format on
155 auto iter = elementwise_op.find(op_name);
156 return (iter != elementwise_op.end());
157 }
158
159 bool IsSplittableOperator(const std::string &op_name) {
160 // clang-format off
161 static const std::set<std::string> splittable_op =
162 {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
163 FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
164 REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
165 MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
166 LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
167 STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
168 SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
169 EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
170 EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
171 BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
172 SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
173 UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
174 UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
175 MATMUL_DDS, DSD_MATMUL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, UNIFORMREAL};
176 // clang-format on
177
178 auto iter = splittable_op.find(op_name);
179 return (iter != splittable_op.end());
180 }
181
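// Decide whether auto-parallel should model this CNode in the cost graph: the node must be a
// Primitive CNode that parallel cares about and whose primitive has a registered OperatorInfo
// (i.e. it is splittable). CASTs that come from the optimizer are skipped.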
182 bool IsAutoParallelCareNode(const CNodePtr &cnode) {
183 MS_EXCEPTION_IF_NULL(cnode);
184 ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
185 if (prim_node == nullptr) {
186 return false;
187 }
188 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
189 if (prim == nullptr) {
190 return false;
191 }
192 bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
193 if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
194 MS_LOG(EXCEPTION) << "Should implement OperatorInfo for: " << prim->name();
195 } else if (prim->name() == CAST) {
196 if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
197 // Do not care about CASTs inserted by the optimizer
198 return false;
199 }
200 return true;
201 }
202 return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
203 }
204
205 // Recording the operators appearing in a for-loop.
206 // Currently, we assume that the operators in different for-loops are identical, and their traversal
207 // orderings are also identical.
208 // Therefore, we create OperatorInfo objects for the operators in a loop (say, loop-3), and reuse them in
209 // the rest of the loops (loop-2, loop-1 and loop-0).
210 std::set<std::string> ops_in_a_loop_;
211 // Returns true if the two operators are in two different loops.
212 // If at least one of the two operators is not in a loop, return false.
213 // If the two operators are in the same loop, return false.
214 bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
215 auto a_op_info = a_cnode->user_data<OperatorInfo>();
216 MS_EXCEPTION_IF_NULL(a_op_info);
217 auto b_op_info = b_cnode->user_data<OperatorInfo>();
218 MS_EXCEPTION_IF_NULL(b_op_info);
219 if ((ops_in_a_loop_.find(a_op_info->name()) == ops_in_a_loop_.end()) ||
220 (ops_in_a_loop_.find(b_op_info->name()) == ops_in_a_loop_.end())) {
221 return false;
222 }
223 size_t a_loop_index = 0, b_loop_index = 0;
224 const auto &a_fullname = a_cnode->fullname_with_scope();
225 if (!GetLoopIndexFromCNode(a_cnode, &a_loop_index)) {
226 MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << a_fullname << " was not included in the set.";
227 }
228 const auto &b_fullname = b_cnode->fullname_with_scope();
229 if (!GetLoopIndexFromCNode(b_cnode, &b_loop_index)) {
230 MS_LOG(EXCEPTION) << "The operator with fullname_with_scope: " << b_fullname << " was not included in the set.";
231 }
232 if (a_loop_index == b_loop_index) {
233 return false;
234 }
235 return true;
236 }
237
238 // 'configured_stra_ops_' includes all operators whose sharding strategies have been configured.
239 std::map<OperatorInfoPtr, StrategyPtr> configured_stra_ops_;
240 void InitCostGraph() {
241 if (entire_costgraph == nullptr) {
242 entire_costgraph = std::make_shared<CostGraph>();
243 }
244 MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
245 CostModelContext::GetInstance()->PrintCostModel();
246 entire_costgraph->Init();
247 configured_stra_ops_.clear();
248 }
249
250 void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const PrimitivePtr &prim,
251 std::unordered_map<std::string, ValuePtr> attrs, bool, StrategyMap *stra_map,
252 const std::string &strategy_key_name) {
253 // In this case, the configured strategy should be extracted to help set the cost
254 StrategyPtr strategyPtr;
255 if (StrategyFound(attrs)) {
256 strategyPtr = parallel::ExtractStrategy(attrs[STRATEGY]);
257 } else {
258 strategyPtr = (*stra_map)[strategy_key_name];
259 }
260 if (strategyPtr != nullptr) {
261 if (prim->name() == RESHAPE) {
262 MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
263 }
264 const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device();
265 // Set cost for this configured strategy
266 if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
267 MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
268 } else if (fully_use_devices) {
269 // If configured to fully use devices, then check the user-specified strategy
270 int64_t used_devices = operator_info->used_devices();
271 MS_EXCEPTION_IF_NULL(g_device_manager);
272 auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
273 // 'used_devices == 1' means the ALL-1 strategy, which is valid in auto-parallel
274 if (used_devices == 1) {
275 (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
276 return;
277 }
278 // 'used_devices == -1' means that 'used_devices_' is not set
279 if ((used_devices == -1) || LongToSize(used_devices) != total_device_num) {
280 MS_LOG(EXCEPTION) << "In the current configuration 'fully_use_devices' = True, "
281 << "but the specified strategy uses devices: " << used_devices
282 << ", total devices: " << total_device_num
283 << ", try to set 'set_algo_parameters(fully_use_devices=False)' "
284 "in package 'mindspore.parallel'.";
285 }
286 }
287 (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
288 }
289 }
290
291 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
292 StrategyMap *stra_map) {
293 MS_EXCEPTION_IF_NULL(prim);
294 MS_EXCEPTION_IF_NULL(cnode);
295 auto attrs = prim->attrs();
296 std::vector<Shapes> shape_list = ExtractShape(cnode);
297 if (shape_list.empty()) {
298 MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
299 }
300 // Create an OperatorInfo instance
301 OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
302 MS_EXCEPTION_IF_NULL(operator_info);
303 // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
304 std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
305 if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
306 MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
307 return nullptr;
308 }
309 // Set the data type for inputs and outputs of this OperatorInfo
310 auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
311 auto outputs_type = ExtractOutputTypeByNode(cnode);
312 std::vector<size_t> outputs_type_length;
313 outputs_type_length.reserve(outputs_type.size());
314 std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
315 GetLengthOfDataType);
316 if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
317 MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
318 return nullptr;
319 }
320 if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
321 MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
322 return nullptr;
323 }
324 // When the 'inputs' contain numerical values for some operators, these values should be extracted from
325 // the ANF graph
326 auto &inputs = cnode->inputs();
327 std::vector<ValuePtr> input_value;
328 for (size_t index = 1; index < inputs.size(); ++index) {
329 if (inputs[index]->isa<ValueNode>()) {
330 input_value.push_back(GetValueNode(inputs[index]));
331 } else {
332 input_value.emplace_back(nullptr);
333 }
334 }
335 operator_info->set_input_value(input_value);
336 operator_info->set_outputs_dtype(cnode->Type());
337 operator_info->set_cnode(cnode);
338 // key of strategy map
339 std::string strategy_key_name = "";
340 auto param_names = NodeParameterName(cnode, -1, 0);
341 if (!param_names.empty()) {
342 strategy_key_name = prim->name() + "_" + param_names[0].first;
343 }
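// The checkpointed strategy is keyed by "<PrimName>_<first parameter name>",
// e.g. "MatMul_fc1.weight" (hypothetical parameter name, for illustration only).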
344 bool load_strategy_from_ckpt =
345 StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
346 // If no strategy has been configured for this operator, then candidate strategies are generated for
347 // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
348 // If the strategy is set to be loaded from a checkpoint, loading it from the checkpoint is preferred.
349 if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
350 // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used in preparation
351 // for the BatchParallelInfo operator
352 operator_info->ComputeBatchSplitFlagList();
353 if (operator_info->GenerateStrategies(0) != SUCCESS) {
354 MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
355 return nullptr;
356 }
357 if (ParallelContext::GetInstance()->sharding_propagation() &&
358 (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
359 const auto &swc_vec = operator_info->GetStrategyCost();
360 if (swc_vec.empty()) {
361 MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
362 }
363 MS_EXCEPTION_IF_NULL(swc_vec[0]->strategy_ptr);
364 (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
365 }
366 // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
367 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
368 if (approximation) {
369 operator_info->ApproximateStrategies();
370 MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
371 }
372 } else {
373 SetStrategyToOperator(operator_info, prim, attrs, is_last_nodes, stra_map, strategy_key_name);
374 }
375 return operator_info;
376 }
377
378 bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_name) {
379 bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
380 (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
381 (current_op_ptr->name().find(prim_name + "Info") == std::string::npos);
382 if (prim_name == GATHERV2) {
383 is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos);
384 }
385 return is_find_wrong;
386 }
387
388 // Using CNode's UniqueIds to construct nodes
389 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
390 MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
391 // The map from CNode's UniqueId to its operatorInfo
392 std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
393 // The operator_infos in a loop
394 std::vector<OperatorInfoPtr> operators_in_forloop;
395 // Key: i-th loop; Value: index of 'operators_in_forloop'
396 std::map<size_t, size_t> loop_to_ops;
397 // extract strategy from checkpoint for multi-train
398 StrategyMap stra_map;
399 if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
400 if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
401 MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
402 }
403 }
404
405 for (auto &node : all_nodes) {
406 // NOTE: we only care about splittable Primitive operators
407 auto cnode = node->cast<CNodePtr>();
408 bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
409 if (bool_result) {
410 continue;
411 }
412 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
413 if (!IsAutoParallelCareNode(cnode)) {
414 // Needed by rec_parser
415 if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
416 auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
417 if (prev_cnode != nullptr) {
418 entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
419 }
420 }
421 continue;
422 }
423 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
424 MS_EXCEPTION_IF_NULL(prim);
425
426 auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
427 if (search_cnode == from_cnode_to_info.end()) {
428 size_t loop_index = 0;
429 bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
430 const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
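// If this CNode sits in a for-loop whose body has already been processed in an earlier iteration,
// reuse the OperatorInfo created there instead of creating a new one.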
431 if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) {
432 const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
433 if (IsFindWrong(current_op_ptr, prim->name())) {
434 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
435 << " does not match the Prim: " << prim->name()
436 << ". The fullname_with_scope: " << cnode->fullname_with_scope();
437 }
438 loop_to_ops[loop_index]++;
439 cnode->set_user_data<OperatorInfo>(current_op_ptr);
440 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
441 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
442 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
443 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
444 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
445 continue;
446 }
447 bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
448 auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
449 if (operator_info == nullptr) {
450 return FAILED;
451 }
452 // Needed by rec_parser
453 operator_info->set_type(prim->name());
454 operator_info->set_last_node_flag(is_last_nodes);
455 std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
456
457 entire_costgraph->AddOperator(operator_info);
458 cnode->set_user_data<OperatorInfo>(operator_info);
459 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
460 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
461 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
462 << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
463 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
464 if (single_loop && is_in_loop) {
465 operators_in_forloop.push_back(operator_info);
466 ops_in_a_loop_.insert(operator_info->name());
467 loop_to_ops[loop_index]++;
468 }
469 // Needed by rec_parser
470 entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
471 } else {
472 // Two CNODEs' UniqueIds should not be equal
473 MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
474 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
475 << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
476 }
477 }
478
479 MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
480 return SUCCESS;
481 }
482
483 void SetOperatorToCNode(const OperatorInfoPtr &current_op_ptr, const PrimitivePtr &prim, const CNodePtr &cnode) {
484 if (current_op_ptr == nullptr) {
485 MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
486 } else {
487 if (IsFindWrong(current_op_ptr, prim->name())) {
488 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
489 << " does not match the Prim: " << prim->name();
490 }
491
492 // Needed by rec_parser
493 ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
494
495 cnode->set_user_data<OperatorInfo>(current_op_ptr);
496 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
497 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
498 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
499 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
500 }
501 }
502
503 // Using CNodes' UniqueIdThroughCopy values to construct nodes
504 Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
505 MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
506 // The map from CNode's UniqueIdThroughCopy to its operatorInfo
507 std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
508 // The operator_infos in a loop
509 std::vector<OperatorInfoPtr> operators_in_forloop;
510 // Key: i-th loop; Value: index of 'operators_in_forloop'
511 std::map<size_t, size_t> loop_to_ops;
512 // extract strategy from checkpoint for multi-train
513 StrategyMap stra_map;
514 if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
515 StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
516 MS_LOG(WARNING) << "Load strategy checkpoint failed";
517 return FAILED;
518 }
519 for (auto &node : all_nodes) {
520 // NOTE: we only care about splittable Primitive operators
521 auto cnode = node->cast<CNodePtr>();
522 if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
523 continue;
524 }
525 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
526 if (!IsAutoParallelCareNode(cnode)) {
527 // Needed by rec_parser
528 if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
529 auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
530 if (prev_cnode != nullptr) {
531 entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
532 }
533 }
534 continue;
535 }
536 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
537
538 // Find the operatorInfo if it exists
539 auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
540 if (search_cnode == from_cnode_to_info.end()) {
541 size_t loop_index = 0;
542 bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index);
543 const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
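// As above, reuse an OperatorInfo created in an earlier iteration of the same for-loop when possible.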
544 bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size());
545 if (is_op_created) {
546 const auto &current_op_ptr = operators_in_forloop[loop_to_ops[loop_index]];
547 if (IsFindWrong(current_op_ptr, prim->name())) {
548 MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
549 << " does not match the Prim: " << prim->name()
550 << ". The fullname_with_scope: " << cnode->fullname_with_scope();
551 }
552 loop_to_ops[loop_index]++;
553 cnode->set_user_data<OperatorInfo>(current_op_ptr);
554 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
555 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
556 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
557 << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
558 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), current_op_ptr));
559 continue;
560 }
561 // In this case, the corresponding OperatorInfo is not created, create the new one.
562 bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
563 auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
564 MS_EXCEPTION_IF_NULL(operator_info);
565
566 // Needed by rec_parser
567 operator_info->set_type(prim->name());
568 operator_info->set_last_node_flag(is_last_nodes);
569 std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
570
571 entire_costgraph->AddOperator(operator_info);
572 cnode->set_user_data<OperatorInfo>(operator_info);
573 MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
574 << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
575 << ", CNode fullname_with_scope: " << cnode->fullname_with_scope()
576 << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
577 (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
578 if (single_loop && is_in_loop) {
579 operators_in_forloop.push_back(operator_info);
580 ops_in_a_loop_.insert(operator_info->name());
581 loop_to_ops[loop_index]++;
582 }
583 // Needed by rec_parser
584 entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
585 } else {
586 SetOperatorToCNode(search_cnode->second, prim, cnode);
587 }
588 }
589
590 MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
591 return SUCCESS;
592 }
593
594 void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const OperatorInfoPtr &node_op_info,
595 const CNodePtr &cnode, const CNodePtr &prev_cnode, const PrimitivePtr &prim,
596 const PrimitivePtr &prev_prim, size_t output_index, size_t input_index,
597 size_t *edge_count) {
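// Note: 'input_index' is the 1-based position among the CNode's inputs (input 0 holds the primitive),
// while edge positions in the cost graph are 0-based, hence the 'input_index - 1' below.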
598 std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
599 // If the edge between these two operators has already been added, then the edge will not be added again.
600 if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, input_index - 1)) {
601 return;
602 }
603 EdgePtr edge_ptr;
604 MS_LOG(INFO) << "Creating edge: " << edge_name;
605 if (IsOperatorsInTwoSeparateLoops(prev_cnode, cnode)) {
606 MS_LOG(INFO) << "prev_cnode_fullname: " << prev_cnode->fullname_with_scope()
607 << ", cnode_fullname: " << cnode->fullname_with_scope();
608 MS_LOG(INFO) << "The two operators are in two separate for-loops, so skip creating the edge.";
609 return;
610 }
611 const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
612 bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
613 (stra_follow && IsElementWiseOperator(prev_prim->name()));
614 if (follow_strategy) {
615 // Redistribution is not allowed on the edge.
616 // Elementwise operators have the same strategy as their previous operators.
617 edge_ptr =
618 std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false, true);
619 } else {
620 edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, input_index - 1, false);
621 }
622
623 // Init costs for this edge
624 if (edge_ptr->InitEdgeCost() != SUCCESS) {
625 MS_LOG(EXCEPTION) << "Edge cost initialization failed";
626 }
627 node_op_info->AddPrevEdge(edge_ptr);
628 prev_op_info->AddSuccEdge(edge_ptr);
629 entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
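// With sharding propagation enabled, if the previous operator is a CAST feeding an operator whose
// strategy is already configured, derive the CAST's strategy from that successor (picking one with
// zero communication on this edge) and re-initialize the edge cost accordingly.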
630 if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
631 (configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
632 const auto next_op_stra = configured_stra_ops_[node_op_info];
633 const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithZeroComm(next_op_stra);
634 if (cast_stra == nullptr) {
635 MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
636 }
637 prev_op_info->ClearStrategyCost();
638 if (prev_op_info->SetCostUnderStrategy(cast_stra) != SUCCESS) {
639 MS_LOG(EXCEPTION) << "Failure: operator " << prev_op_info->name() << " SetCostUnderStrategy failed";
640 }
641 if (edge_ptr->InitEdgeCost() != SUCCESS) {
642 MS_LOG(EXCEPTION) << "Edge cost re-initialization failed.";
643 }
644 MS_LOG(INFO) << "Set strategy for: " << prev_op_info->name() << " under the strategy of: " << node_op_info->name();
645 (void)configured_stra_ops_.emplace(prev_op_info, cast_stra);
646 }
647 MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " << node_op_info->name();
648 (*edge_count)++;
649 }
650
651 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
652 // Step 2
653 MS_LOG(INFO) << "Constructing edges for cost graph begins.";
654 for (auto &node : all_nodes) {
655 auto cnode = node->cast<CNodePtr>();
656 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
657 continue;
658 }
659 auto &inputs = cnode->inputs();
660 ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
661 if (!IsAutoParallelCareNode(cnode)) {
662 continue;
663 }
664 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
665 size_t edge_count = 0;
666 auto node_op_info = cnode->user_data<OperatorInfo>();
667
668 for (size_t i = 1; i < inputs.size(); ++i) {
669 auto prev_cnode = inputs[i]->cast<CNodePtr>();
670 bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
671 if (bool_result_prev_cnode) {
672 continue;
673 }
674 ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
675 PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
676 size_t output_index = 0;
677
678 while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
679 (prev_prim->name() == DEPEND)) {
680 if (IsAutoParallelCareNode(prev_cnode)) {
681 auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
682 CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
683 &edge_count);
684 break;
685 } else if (prev_prim->name() == prim::kTupleGetItem) {
686 // In this case, 'prev_cnode' is a 'tuple_getitem'; the actual precursor node is the node before
687 // this 'tuple_getitem'
688 MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
689 output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
690 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
691 bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
692 if (bool_result_tuple) {
693 break;
694 }
695 prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
696 prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
697 if (!IsAutoParallelCareNode(prev_cnode)) {
698 MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
699 }
700 MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
701 << "and will create an edge between the Operator before "
702 << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
703 } else if (prev_prim->name() == DEPEND) {
704 // In this case, 'prev_cnode' is a 'depend'; the actual precursor node is the node before
705 // this 'depend'
706 MS_LOG(INFO) << "Jumping the 'depend' operator.";
707 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
708 bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
709 if (bool_result_depend) {
710 break;
711 }
712 prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
713 prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
714 MS_LOG(INFO) << "Jumped the 'depend' operator, "
715 << "and will create an edge between the Operator before "
716 << "'depend' and the Operator after 'depend'.";
717 }
718 }
719 }
720 MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
721 }
722 // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
723 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
724 if (approximation) {
725 entire_costgraph->CheckApproximateCostGraphEdges();
726 }
727
728 MS_LOG(INFO) << "Constructing edges for cost graph ends.";
729 }
730
731 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
732 // Step 3
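// For a Parameter (identified by its RefKey name) that is used by more than one operator, create a
// TmpIdentity operator standing for that Parameter and add an edge from it to every using operator,
// so the cost model accounts for the Parameter being shared.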
733 for (auto &node : all_nodes) {
734 ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode);
735 auto parameter_name = parameter_users_info.first;
736 auto target_parameter = parameter_users_info.second.first;
737 auto target_set = parameter_users_info.second.second;
738 if (target_set.size() <= 1) {
739 continue;
740 }
741
742 // Rule out the case where a Parameter is used by an Operator, but the Operator appears in multiple CNODEs
743 std::set<std::string> target_without_duplicate;
744 for (auto &target : target_set) {
745 auto target_cnode = target.first->cast<CNodePtr>();
746 auto input_index = target.second;
747 (void)target_without_duplicate.insert(std::to_string(input_index) +
748 target_cnode->user_data<OperatorInfo>()->name());
749 }
750 if (target_without_duplicate.size() <= 1) {
751 continue;
752 }
753
754 // At this point, it is certain that this Parameter (RefKey) is being used by multiple Operators.
755 OperatorInfoPtr tmp_identity_ptr;
756 bool new_identity = false;
757 std::string tmp_identity_name;
758 auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
759 if (returned_identity != nullptr) {
760 // In this case, the TmpIdentityInfo instance has already been created
761 new_identity = false;
762 tmp_identity_ptr = returned_identity;
763 tmp_identity_name = tmp_identity_ptr->name();
764 } else {
765 // In this case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created.
766 new_identity = true;
767 // 1) extract input shape from this Parameter
768 MS_EXCEPTION_IF_NULL(target_parameter);
769 AbstractBasePtr abstract = target_parameter->abstract();
770 if (abstract == nullptr) {
771 MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
772 }
773 auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
774 if (input_shape == nullptr) {
775 MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
776 }
777 Shape shape = input_shape->shape();
778 Shapes inputs_shape = {shape};
779 Shapes outputs_shape = {shape};
780 // 2) init the attr
781 std::unordered_map<std::string, ValuePtr> attr = {};
782
783 // Create the TmpIdentity instance
784 tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
785 tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
786 TOTAL_OPS++;
787 tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
788 // Set the parameter and type lengths for inputs and outputs
789 std::vector<bool> is_parameter;
790 auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
791 MS_EXCEPTION_IF_NULL(casted_target_parameter);
792 is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
793 if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
794 MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
795 }
796 auto node_type = target_parameter->Type();
797 if (node_type->isa<mindspore::TensorType>()) {
798 auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
799 std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
800 if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
801 MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
802 }
803 } else {
804 MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
805 }
806
807 // Generate strategies for this TmpIdentityInfo instance;
808 if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
809 MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name();
810 }
811 }
812 // A flag recording whether new edges have been created or not
813 bool add_identity_edge = false;
814
815 // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
816 for (auto &target : target_set) {
817 auto target_cnode = target.first->cast<CNodePtr>();
818 auto input_index = target.second;
819 auto target_op_info = target_cnode->user_data<OperatorInfo>();
820
821 std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
822 // If the edge between these two operators has already been added, then the edge will not be added again.
823 if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, LongToSize(input_index - 1))) {
824 continue;
825 }
826 std::shared_ptr<Edge> edge_ptr =
827 std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
828 // If 'approximation' is enabled, the edges need to be checked to ensure they have effective costs.
829 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
830 if (approximation) {
831 target_op_info->ExactStrategiesAndRelatedEdges();
832 }
833
834 if (edge_ptr->InitEdgeCost() != SUCCESS) {
835 MS_LOG(EXCEPTION) << "Edge cost initialization failed";
836 }
837 target_op_info->AddPrevEdge(edge_ptr);
838 tmp_identity_ptr->AddSuccEdge(edge_ptr);
839 entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
840 MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
841 << target_op_info->name();
842 add_identity_edge = true;
843 }
844 if (new_identity && add_identity_edge) {
845 // Add the TmpIdentityInfo to the CostGraph only if BOTH conditions are satisfied
846 entire_costgraph->AddOperator(tmp_identity_ptr);
847 }
848 }
849 }
850
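// For each Reshape in the graph, collect the strategy costs of its previous operator and (if any) its
// next operator, then let ReshapeInfo generate its own strategy costs from the resulting
// input-layout/output-layout pairs.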
851 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
852 std::unordered_set<std::string> op_cache;
853 for (auto node : all_nodes) {
854 auto cnode = node->cast<CNodePtr>();
855 if (!FindReshape(cnode, &op_cache)) {
856 continue;
857 }
858 MS_ASSERT(cnode->inputs().size() == 3);
859 // get previous node's strategy_cost_
860 auto pre_node = cnode->input(1);
861 if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) {
862 pre_node = pre_node->cast<CNodePtr>()->input(1);
863 }
864 int64_t out_index = 0;
865 OperatorInfoPtr pre_operator_info;
866 std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
867 auto operator_info = cnode->user_data<OperatorInfo>();
868 if (pre_node->isa<Parameter>()) {
869 auto reshape_info1 = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
870 reshape_info1->SetCostForReshapeWithParameter();
871 pre_operator_info = reshape_info1;
872 pre_stra_costs = reshape_info1->strategy_cost();
873 } else {
874 if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index, 0)) {
875 MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
876 }
877 pre_stra_costs = pre_operator_info->strategy_cost();
878 }
879 // get next node's strategy_cost_
880 int64_t in_index = 0;
881 OperatorInfoPtr next_operator_info;
882 bool is_next_reshape = false;
883 std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
884 bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, &is_next_reshape, 0);
885 if (!find_next_node) {
886 MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
887 }
888 // set input_layout and output_layout for reshape.
889 // init reshape and set cost for each input_layout and output_layout.
890 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
891 reshape_info->set_pre_operator_name(pre_operator_info->name());
892 reshape_info->set_pre_operator_index(out_index);
893 if (find_next_node) {
894 next_stra_costs = next_operator_info->strategy_cost();
895 reshape_info->set_next_operator_name(next_operator_info->name());
896 reshape_info->set_next_operator_index(in_index);
897 }
898 bool is_prev_param = pre_node->isa<Parameter>();
899 if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
900 is_next_reshape) != SUCCESS) {
901 MS_LOG(EXCEPTION) << "Reshape failed to generate strategy_costs!";
902 }
903 }
904 }
905
906 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
907 // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
908 // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
909 // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
910 // for each OperatorInfo;
911 // Step 1.1: Deal with 'Reshape':
912 // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
913 // layout as its output layout.
914 // Step 2: Traverse the ANF graph, and create EDGES for costgraph:
915 // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
916 // for each edge, based on the strategies of two OperatorInfos;
917 // Step 3: Augment the costgraph:
918 // take care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity
919 // operator for this Parameter, and add an edge for the use of this Parameter by each
920 // subsequent operator;
921 // Step 3.1: Calculate memory usage:
922 // note that the memory usage calculation differs between the training phase and the inference phase.
923 // Step 4: Run the strategy searching algorithm:
924 // If 'sharding_propagation' is configured to be true, then the configured-sharding-strategies will propagate
925 // to the non-configured operators, with the goal of minimizing redistribution cost.
926 // Otherwise, the DP algorithm is used to search strategies for the costgraph. Note that there may be several connected
927 // components in the costgraph, and the DP algorithm runs on each of them.
928 //
929 // OUTPUT: the determined strategy for each operator.
930
931 InitCostGraph();
932 // Step 1
933 if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
934 if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
935 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
936 << entire_costgraph->GetOperators().size() << " operators.";
937 } else {
938 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
939 }
940 } else {
941 if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
942 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
943 << entire_costgraph->GetOperators().size() << " operators.";
944 } else {
945 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
946 }
947 }
948 // Step 1.1
949 ReshapeCostCompute(all_nodes);
950 // Step 2
951 ConstructCostGraphEdges(all_nodes);
952 MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
953 << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
954
955 // Step 3: Augment the costgraph.
956 AugmentCostGraph(all_nodes);
957 auto num_ops = entire_costgraph->GetOperators().size();
958 SetOpsNumToExecutor(num_ops);
959 auto num_edges = entire_costgraph->GetNumEdges();
960 MS_LOG(INFO) << "After the augmenting procedure, there are " << num_ops << " operators, and " << num_edges
961 << " edges.";
962
963 // Step 3.1: Calculate the memory usage
964 if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
965 MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
966 }
967
968 // Step 4: run the strategy searching algorithm
969 if (ParallelContext::GetInstance()->sharding_propagation()) {
970 entire_costgraph->StrategyPropagate(configured_stra_ops_);
971 configured_stra_ops_.clear();
972 } else if (GetStrategy(entire_costgraph) != SUCCESS) {
973 MS_LOG(ERROR) << "Strategy search for the cost-graph failed";
974 return FAILED;
975 }
976 MS_LOG(INFO) << "Searching strategy succeeded.";
977
978 if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
979 MS_LOG(INFO) << "Init selected strategy succeeded.";
980 } else {
981 MS_LOG(EXCEPTION) << "Init selected strategy failed.";
982 }
983
984 // print the selected strategy
985 for (auto &op : entire_costgraph->GetOperators()) {
986 StrategyPtr s_strategy = op->selected_strategy();
987 MS_LOG(INFO) << op->name() << " : The strategy is:";
988 PrintStrategy(s_strategy);
989 }
990 ops_in_a_loop_.clear();
991
992 return SUCCESS;
993 }
994
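// Replace every occurrence of the map key (the UniqueId recorded for a skipped tuple_getitem/depend
// node) in the input tensor names with the mapped value (the UniqueId of the node before it), and
// return the updated list. Needed by rec_parser.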
995 std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
996 std::vector<std::vector<std::string>> input_tensor_names) {
997 for (size_t j = 0; j < input_tensor_names.size(); j++) {
998 for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
999 if (it->first == input_tensor_names[j][k]) {
1000 input_tensor_names[j][k] = it->second;
1001 break;
1002 }
1003 }
1004 }
1005 return input_tensor_names;
1006 }
1007
1008 CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
1009 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1010 if (prim->name() == prim::kTupleGetItem || prim->name() == DEPEND) {
1011 auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
1012 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1013 return nullptr;
1014 }
1015 auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1016 while (prev_prim->name() == prim::kTupleGetItem || prev_prim->name() == DEPEND) {
1017 prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
1018 if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
1019 return nullptr;
1020 }
1021 prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
1022 }
1023 return prev_cnode;
1024 }
1025 return nullptr;
1026 }
1027
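// Locate the already-created OperatorInfo with the given name, then replace every occurrence of this
// CNode's UniqueId in the recorded input tensor name lists with the first entry of that operator's
// row, so the rec_parser refers to a single operator.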
1028 void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
1029 size_t iter_ops = 0;
1030 for (auto op : entire_costgraph->GetOperators()) {
1031 if (op->name() == name) {
1032 break;
1033 }
1034 iter_ops = iter_ops + 1;
1035 }
1036
1037 std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1038 for (size_t i = 0; i < input_tensor_names.size(); i++) {
1039 for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
1040 if (input_tensor_names[i][j] == uniqueid) {
1041 input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
1042 }
1043 }
1044 }
1045
1046 entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
1047 }
1048
1049 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
1050 InitCostGraph();
1051 if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
1052 if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
1053 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1054 << entire_costgraph->GetOperators().size() << " operators.";
1055 } else {
1056 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1057 }
1058 } else {
1059 if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
1060 MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
1061 << entire_costgraph->GetOperators().size() << " operators.";
1062 } else {
1063 MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
1064 }
1065 }
1066 ReshapeCostCompute(all_nodes);
1067
1068 auto ops = entire_costgraph->GetOperators();
1069 std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
1070 auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
1071 for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
1072 input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
1073 }
1074 std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
1075
1076 std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
1077 std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
1078 graph = EliminateGraph(graph, eli_list, index_list);
1079
1080 size_t num_device = g_device_manager->DeviceNum();
1081 const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
1082 if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
1083 MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
1084 } else {
1085 MS_LOG(ERROR) << "PartitionForAllDevices failed.";
1086 return FAILED;
1087 }
1088
1089 bool is_training = true;
1090 if (!root->has_flag(TRAINING)) {
1091 is_training = false;
1092 }
1093 GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);
1094
1095 if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
1096 MS_LOG(INFO) << "Init selected strategy succeeded.";
1097 } else {
1098 MS_LOG(ERROR) << "Init selected strategy failed.";
1099 return FAILED;
1100 }
1101
1102 // print the selected strategy
1103 for (auto &op : entire_costgraph->GetOperators()) {
1104 StrategyPtr s_strategy = op->selected_strategy();
1105 MS_LOG(INFO) << op->name() << " : The strategy is:";
1106 PrintStrategy(s_strategy);
1107 }
1108
1109 return SUCCESS;
1110 }
1111 } // namespace parallel
1112 } // namespace mindspore
1113