• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/stage_compute.h"
18 
#include <algorithm>
#include <cmath>
#include <map>
#include <regex>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "mindspore/core/ops/array_ops.h"
27 #include "frontend/parallel/step_parallel.h"
28 #include "utils/parallel_node_check.h"
29 #include "ir/func_graph.h"
30 #include "include/common/utils/parallel_context.h"
31 #include "mindspore/core/utils/ms_utils.h"
32 
33 namespace mindspore {
34 namespace parallel {
35 
36 constexpr size_t PARSING_FAILED = SIZE_MAX;
37 constexpr size_t Kilo = 1024;
38 constexpr double MARGIN_FACTOR = 1.1;
39 
40 // Thousand separators for memory numbers
TSepBytes(size_t n)41 string TSepBytes(size_t n) {
42   string res = std::to_string(n / Kilo / Kilo);
43   int thousand_digit_num = 3;
44   for (int i = static_cast<int>(res.size()) - thousand_digit_num; i > 0; i -= thousand_digit_num) res.insert(i, ",");
45   return res + "M";
46 }
47 
GetNodes(const FuncGraphPtr & root)48 std::vector<AnfNodePtr> GetNodes(const FuncGraphPtr &root) {
49   AnfNodePtr ret_forward = root->get_return();
50   return DeepScopedGraphSearch(ret_forward);
51 }
52 
53 // Get Number of Layers ((each model has unique layers name to analyse))
GetNumLayers(const FuncGraphPtr & root)54 size_t GetNumLayers(const FuncGraphPtr &root) {
55   const std::string kHeadLayer = "Head";
56   const std::string kNormLayer = "Norm";
57   const std::string kEmbeddingLayer = "Embedding";
58   const std::string kLinearLayer = "Linear";
59   std::vector<FuncGraphPtr> pipeline_cells;
60   size_t num_layers = 0;
61 
62   auto forward_nodes = GetNodes(root);
63   for (auto node : forward_nodes) {
64     if (!node->isa<CNode>()) continue;
65 
66     auto cnode = node->cast<CNodePtr>();
67     if (!IsValueNode<FuncGraph>(cnode->input(0))) {
68       continue;
69     }
70 
71     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
72     if (graph->stage() == -1 ||
73         std::find(pipeline_cells.begin(), pipeline_cells.end(), graph) != pipeline_cells.end()) {
74       continue;
75     }
76 
77     pipeline_cells.push_back(graph);
78     std::string name = graph->ToString();
79     // Remove pre/post cells
80     if (!(name.find(kHeadLayer) != std::string::npos || name.find(kNormLayer) != std::string::npos ||
81           name.find(kEmbeddingLayer) != std::string::npos || name.find(kLinearLayer) != std::string::npos)) {
82       MS_LOG(DEBUG) << name << " is counted as a layer";
83       num_layers++;
84     } else {
85       MS_LOG(DEBUG) << name << " is NOT a normal layer";
86     }
87   }
88 
89   return (num_layers > 0) ? num_layers : PARSING_FAILED;
90 }
91 
GetNumDevices()92 size_t GetNumDevices() { return g_device_manager->DeviceNum(); }
93 
94 // Get parallel_optimizer
HasParallelOptimizer(const FuncGraphPtr & root)95 bool HasParallelOptimizer(const FuncGraphPtr &root) {
96   return parallel::ParallelContext::GetInstance()->enable_parallel_optimizer();
97 }
98 
99 // Check if recomputation was chosen. Currently only able to check select_recompute
HasRecompute(const FuncGraphPtr & root)100 bool HasRecompute(const FuncGraphPtr &root) {
101   auto forward_nodes = GetNodes(root);
102   for (auto &forward_node : forward_nodes) {
103     if (!forward_node->isa<CNode>()) {
104       continue;
105     }
106     auto cnode = forward_node->cast<CNodePtr>();
107     if (IsValueNode<FuncGraph>(cnode->input(0))) {
108       auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
109       if (fg->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
110         MS_LOG(DEBUG) << "found recompute cell " << fg->ToString();
111         return true;
112       }
113     }
114 
115     if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
116       auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
117       if (current_prim != nullptr) {
118         auto prim_recompute_attr = current_prim->GetAttr(kAttrRecompute);
119         if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
120           auto recomputed = GetValue<bool>(prim_recompute_attr);
121           if (recomputed) {
122             MS_LOG(DEBUG) << "found recompute node " << current_prim->name();
123             return true;
124           }
125         }
126       }
127     }
128   }
129 
130   return false;
131 }
132 
133 // Get DP and MP dimensions
// Infer the (data-parallel, model-parallel) split of one pipeline stage by
// inspecting the strategies recorded on MatMul nodes of the cost-model graph.
// Each MatMul "votes" for a (dp,mp) pair; the most frequent pair wins.
// Returns {PARSING_FAILED, PARSING_FAILED} if no usable strategy is found.
std::tuple<size_t, size_t> GetDPAndMP(const std::shared_ptr<Graph> &graph, const size_t stage) {
  // Vote table: "dp,mp" string -> number of MatMul nodes that chose it.
  std::map<std::string, int> strategy_occurrence;

  size_t dp = 0;
  size_t mp = 0;
  const unsigned int kTargetLength = 2;
  // One strategy entry is rolled back per halving of the stage count.
  size_t roll_back = FloatToSize(log2(stage));
  for (auto &node_ptr : graph->nodes) {
    if (node_ptr.apply.op_type == kRecMatMul) {
      // Index of the strategy decided before the last (roll_back + 1) refinements.
      size_t n_cut = node_ptr.apply.strs.size() - roll_back - 1;
      // An unsigned underflow above wraps to a huge value, which this bound
      // check catches, so the subtraction is safe.
      if (n_cut >= node_ptr.apply.strs.size()) {
        MS_LOG(WARNING) << "Strategy of  " << node_ptr.name << " not available";
        return {PARSING_FAILED, PARSING_FAILED};
      }
      StrategyRec strategy = node_ptr.apply.strs[n_cut];
      // NOTE(review): sizeof(strategy.inputTensor) / sizeof(TensorStr4D) is a
      // compile-time constant (the array's capacity), so this condition always
      // evaluates the same way; it does not actually verify the strategy
      // content — confirm the intended check.
      if (sizeof(strategy.inputTensor) / sizeof(TensorStr4D) >= kTargetLength) {
        MS_LOG(DEBUG) << "inputTensor[0] " << strategy.inputTensor[0].str_w << " " << strategy.inputTensor[0].str_h
                      << " " << strategy.inputTensor[0].str_c << " " << strategy.inputTensor[0].str_n;
        MS_LOG(DEBUG) << "inputTensor[1] " << strategy.inputTensor[1].str_w << " " << strategy.inputTensor[1].str_h
                      << " " << strategy.inputTensor[1].str_c << " " << strategy.inputTensor[1].str_n;
        int mp_strat = 1;
        int dp_strat = 1;
        // The str_* fields are cut fractions (e.g. 0.5 means split in two), so the
        // split count is the reciprocal; guard against dividing by zero.
        if (strategy.inputTensor[1].str_h * strategy.inputTensor[1].str_w != 0) {
          mp_strat = static_cast<int>(1 / (strategy.inputTensor[1].str_h * strategy.inputTensor[1].str_w));
        }
        if (strategy.inputTensor[0].str_h != 0) {
          dp_strat = static_cast<int>(1 / strategy.inputTensor[0].str_h);
        }
        MS_LOG(DEBUG) << "dp_strat: " << dp_strat << ", mp_strat: " << mp_strat;
        // Encode the pair as "dp,mp" to use it as a map key for voting.
        std::string strategy_str = std::to_string(dp_strat) + "," + std::to_string(mp_strat);
        auto it = strategy_occurrence.find(strategy_str);
        if (it == strategy_occurrence.end()) {
          strategy_occurrence.insert(std::pair<std::string, int>(strategy_str, 1));
        } else {
          it->second++;
        }
      } else {
        MS_LOG(DEBUG) << "MatMul strategy found but null";
      }
    }
  }
  // Take the (DP,MP) that appears the most
  int occurrence = 0;
  for (auto it = strategy_occurrence.begin(); it != strategy_occurrence.end(); it++) {
    if (it->second > occurrence) {
      auto stra = it->first;
      auto pos = stra.find(",");
      // Decode the "dp,mp" key back into the two integers.
      dp = static_cast<size_t>(std::stoi(stra.substr(0, pos)));
      mp = static_cast<size_t>(std::stoi(stra.substr(pos + 1, stra.size() - pos)));
      occurrence = it->second;
    }
  }
  if (dp > 0 && mp > 0) {
    return {dp, mp};
  }
  return {PARSING_FAILED, PARSING_FAILED};
}
191 
192 // Get Vocab Size and Hidden Size as a tuple
GetVocabAndHiddenSize(const FuncGraphPtr & root)193 std::tuple<size_t, size_t> GetVocabAndHiddenSize(const FuncGraphPtr &root) {
194   size_t hidden_size = 0;
195   size_t vocab_size = 0;
196   std::vector<AnfNodePtr> parameters = root->parameters();
197   for (auto &p : parameters) {
198     auto parameter_ptr = p->cast<ParameterPtr>();
199     Shapes param_shapes = GetNodeShape(p);
200     if (hidden_size == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.attention.*.weight"))) {
201       hidden_size = static_cast<size_t>(param_shapes[0][1]);
202       MS_LOG(DEBUG) << "Parameter for hidden size: " << parameter_ptr->name().c_str() << "; with shape " << param_shapes
203                     << "; h = " << hidden_size;
204     } else if (vocab_size == 0 && (std::regex_match(parameter_ptr->name().c_str(), std::regex(".*word_embedding.*")) ||
205                                    std::regex_match(parameter_ptr->name().c_str(), std::regex(".*tok_embeddings.*")))) {
206       vocab_size = static_cast<size_t>(param_shapes[0][0]);
207       MS_LOG(DEBUG) << "Parameter for vocab size: " << parameter_ptr->name().c_str() << "; with shape " << param_shapes
208                     << "; v = " << vocab_size;
209     } else {
210       MS_LOG(DEBUG) << "Parameter " << parameter_ptr->name().c_str() << "; with shape " << param_shapes;
211     }
212     if (hidden_size > 0 && vocab_size > 0) break;
213   }
214   if (hidden_size > 0 && vocab_size > 0) {
215     return {hidden_size, vocab_size};
216   }
217   return {PARSING_FAILED, PARSING_FAILED};
218 }
219 
220 // Get Attention Heads and Sequence Length
GetSeqLengthAndAttentionHeads(const FuncGraphPtr & root)221 std::tuple<size_t, size_t> GetSeqLengthAndAttentionHeads(const FuncGraphPtr &root) {
222   size_t seq_length = 0;
223   size_t attention_heads = 0;
224   auto forward_nodes = GetNodes(root);
225   const size_t kTargetShape = 4;
226   for (auto &forward_node : forward_nodes) {
227     if (forward_node->isa<CNode>()) {
228       auto cnode = forward_node->cast<CNodePtr>();
229       if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
230         auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
231         if (seq_length == 0 && strcmp(current_prim->name().c_str(), GET_NEXT) == 0) {
232           Shapes param_shapes = GetNodeShape(cnode);
233           MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
234           seq_length = static_cast<size_t>(param_shapes[0][1] - 1);
235         }
236         if (attention_heads == 0 && (strcmp(current_prim->name().c_str(), SOFTMAX) == 0 ||
237                                      strcmp(current_prim->name().c_str(), FLASH_ATTENTION_SCORE) == 0)) {
238           Shapes param_shapes = GetNodeShape(cnode);
239           if (param_shapes[0].size() == kTargetShape) {
240             MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
241             attention_heads = static_cast<size_t>(param_shapes[0][1]);
242           }
243         }
244         if (attention_heads > 0 && seq_length > 0) {
245           break;
246         }
247       }
248     }
249   }
250   if (seq_length > 0 && attention_heads > 0) {
251     return {seq_length, attention_heads};
252   }
253   return {PARSING_FAILED, PARSING_FAILED};
254 }
255 
256 // Get num micro
GetNumMicro(const FuncGraphPtr & root)257 size_t GetNumMicro(const FuncGraphPtr &root) {
258   auto manager = root->manager();
259   AnfNodePtr virtual_dataset;
260   for (auto &fg : manager->func_graphs()) {
261     for (auto &node : fg->nodes()) {
262       if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
263         virtual_dataset = node;
264         break;
265       }
266     }
267   }
268   auto node_user_map = manager->node_users();
269   auto node_users = node_user_map[virtual_dataset];
270   for (auto &node_user : node_users) {
271     if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
272       auto data_users = manager->node_users()[node_user.first];
273       auto node_first = data_users.front().first;
274       if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
275         data_users = node_user_map[node_first];
276       }
277       MS_LOG(DEBUG) << "micro batch size found: " << int64_t(data_users.size());
278       return int64_t(data_users.size());
279     }
280   }
281   MS_LOG(WARNING) << "micro batch size not found";
282   return PARSING_FAILED;
283 }
284 
285 // Get per batch
GetPerBatch(const FuncGraphPtr & root,size_t seq_l)286 size_t GetPerBatch(const FuncGraphPtr &root, size_t seq_l) {
287   size_t per_batch = 0;
288   auto forward_nodes = GetNodes(root);
289   for (auto &forward_node : forward_nodes) {
290     if (forward_node->isa<CNode>()) {
291       auto cnode = forward_node->cast<CNodePtr>();
292       if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
293         auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
294         if (per_batch == 0 && strcmp(current_prim->name().c_str(), MATMUL) == 0) {
295           Shapes param_shapes = GetNodeShape(cnode);
296           MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
297           per_batch = static_cast<size_t>(param_shapes[0][0]) / seq_l;
298           MS_LOG(DEBUG) << "batch size found: " << per_batch;
299         }
300         if (per_batch > 0) {
301           break;
302         }
303       }
304     }
305   }
306   return (per_batch > 0) ? per_batch : PARSING_FAILED;
307 }
308 
GetFPFromParams(const FuncGraphPtr & root)309 std::tuple<size_t, size_t, size_t, size_t> GetFPFromParams(const FuncGraphPtr &root) {
310   size_t fp_params = 0;
311   size_t fp_optim = 0;
312   size_t fp_grads = 0;
313   size_t fp_norm = 0;
314 
315   std::vector<AnfNodePtr> parameters = root->parameters();
316   for (auto &p : parameters) {
317     if (p == nullptr) {
318       continue;
319     }
320     auto parameter_ptr = p->cast<ParameterPtr>();
321     mindspore::TypePtr element_type;
322     auto data_type = p->Type();
323     MS_EXCEPTION_IF_NULL(data_type);
324     if (!data_type->isa<mindspore::TensorType>()) {
325       continue;
326     }
327     element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
328     MS_EXCEPTION_IF_NULL(element_type);
329     auto type_id = element_type->type_id();
330     if (fp_grads == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex("accu_grads.*embedding.*"))) {
331       fp_grads = GetTypeByte(TypeIdToType(type_id));
332     } else if (fp_optim == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex("adam_m.*embedding.*"))) {
333       fp_optim = GetTypeByte(TypeIdToType(type_id));
334     } else if (fp_params == 0 &&
335                (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.*embedding.*")) ||
336                 std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.*embedding.*")))) {
337       fp_params = GetTypeByte(TypeIdToType(type_id));
338     } else if (fp_norm == 0 &&
339                (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.*attention_norm.*")) ||
340                 std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.*layernorm.*")))) {
341       fp_norm = GetTypeByte(TypeIdToType(type_id));
342     }
343     if (fp_optim > 0 && fp_params > 0 && fp_grads > 0 && fp_norm > 0) {
344       return {fp_optim, fp_params, fp_grads, fp_norm};
345     }
346   }
347 
348   return {fp_params, fp_optim, fp_grads, fp_norm};
349 }
350 
GetFPFromNodes(const FuncGraphPtr & root)351 std::tuple<size_t, size_t> GetFPFromNodes(const FuncGraphPtr &root) {
352   size_t fp_dropout = 0;
353   size_t fp_softmax = 0;
354 
355   auto forward_nodes = GetNodes(root);
356   for (auto &forward_node : forward_nodes) {
357     if (forward_node->isa<CNode>()) {
358       auto cnode = forward_node->cast<CNodePtr>();
359       if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
360         auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
361         mindspore::TypePtr element_type;
362         auto data_type = cnode->Type();
363         MS_EXCEPTION_IF_NULL(data_type);
364         if (!data_type->isa<mindspore::TensorType>()) {
365           if (std::regex_match(current_prim->name().c_str(), std::regex("Dropout"))) {
366             fp_dropout = 1;
367           }
368           continue;
369         }
370         element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
371         MS_EXCEPTION_IF_NULL(element_type);
372         auto type_id = element_type->type_id();
373         if (fp_softmax == 0 && (std::regex_match(current_prim->name().c_str(), std::regex(SOFTMAX)))) {
374           fp_softmax = GetTypeByte(TypeIdToType(type_id));
375         } else if (fp_dropout == 0 && std::regex_match(current_prim->name().c_str(), std::regex(DROPOUT))) {
376           fp_dropout = GetTypeByte(TypeIdToType(type_id));
377         }
378         if (fp_softmax > 0 && fp_dropout > 0) {
379           return {fp_dropout, fp_softmax};
380         }
381       }
382     }
383   }
384 
385   return {fp_dropout, fp_softmax};
386 }
387 
388 // Get FP format for params, optimizer, gradient, norm, softmax, dropout
GetFP_formats(const FuncGraphPtr & root)389 std::tuple<size_t, size_t, size_t, size_t, size_t, size_t> GetFP_formats(const FuncGraphPtr &root) {
390   size_t fp_params = 0;
391   size_t fp_optim = 0;
392   size_t fp_grads = 0;
393   size_t fp_norm = 0;
394   size_t fp_dropout = 0;
395   size_t fp_softmax = 0;
396 
397   std::tie(fp_params, fp_optim, fp_grads, fp_norm) = GetFPFromParams(root);
398   std::tie(fp_dropout, fp_softmax) = GetFPFromNodes(root);
399 
400   if (fp_params == 0 || fp_optim == 0 || fp_grads == 0 || fp_norm == 0) {
401     return {PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED};
402   }
403 
404   return {fp_params, fp_optim, fp_grads, fp_norm, fp_dropout, fp_softmax};
405 }
406 
GetExpansionRatio(const FuncGraphPtr & root)407 size_t GetExpansionRatio(const FuncGraphPtr &root) {
408   std::vector<AnfNodePtr> parameters = root->parameters();
409   for (auto &p : parameters) {
410     auto parameter_ptr = p->cast<ParameterPtr>();
411     Shapes param_shapes = GetNodeShape(p);
412     if (std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.feed_forward.*.weight")) ||
413         std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.output.projection.weight"))) {
414       return param_shapes[0][0];
415     }
416   }
417   return PARSING_FAILED;
418 }
419 
420 // Get FP format for multi head attention block, norms, feed forward block
GetNumTransformerComponents(const FuncGraphPtr & root)421 std::tuple<size_t, size_t, size_t> GetNumTransformerComponents(const FuncGraphPtr &root) {
422   size_t n_mha = 0;
423   size_t n_ff = 0;
424   size_t n_norm = 0;
425 
426   std::vector<AnfNodePtr> parameters = root->parameters();
427   for (auto &p : parameters) {
428     if (p == nullptr) {
429       continue;
430     }
431     auto parameter_ptr = p->cast<ParameterPtr>();
432     if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention.wo.weight")) ||
433         std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.attention.dense1.weight"))) {
434       n_mha++;
435     } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.feed_forward.w1.weight")) ||
436                std::regex_match(parameter_ptr->name().c_str(),
437                                 std::regex("^backbone.blocks.0.output.mapping.weight"))) {
438       n_ff++;
439     } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention_norm.weight")) ||
440                std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.ffn_norm.weight")) ||
441                std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.layernorm.*gamma"))) {
442       n_norm++;
443     }
444   }
445   if (n_mha > 0 && n_ff > 0 && n_norm > 0) {
446     return {n_mha, n_ff, n_norm};
447   }
448   return {PARSING_FAILED, PARSING_FAILED, PARSING_FAILED};
449 }
450 
451 // Count weights matrixes in MHA and FF block
GetNumWeightsTransformer(const FuncGraphPtr & root)452 std::tuple<size_t, size_t> GetNumWeightsTransformer(const FuncGraphPtr &root) {
453   size_t n_weight_MHA = 0;
454   size_t n_weight_FF = 0;
455 
456   std::vector<AnfNodePtr> parameters = root->parameters();
457   for (auto &p : parameters) {
458     if (p == nullptr) {
459       continue;
460     }
461     auto parameter_ptr = p->cast<ParameterPtr>();
462     if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention.w.*weight")) ||
463         std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.attention.*weight"))) {
464       n_weight_MHA++;
465     } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.feed_forward.*.weight")) ||
466                std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.output.*weight"))) {
467       n_weight_FF++;
468     }
469   }
470   if (n_weight_MHA > 0 && n_weight_FF > 0) {
471     return {n_weight_MHA, n_weight_FF};
472   }
473   return {PARSING_FAILED, PARSING_FAILED};
474 }
475 
// Aggregates every parsed model dimension and parallel configuration needed by
// the stage-memory estimator. Pure member initialization; no validation here.
// NOTE(review): C++ initializes members in class-declaration order, not in
// initializer-list order — keep the header's declaration order consistent with
// this list (e.g. vocab_size_ before seq_length_) to avoid -Wreorder surprises.
StageComputing::StageComputing(const FuncGraphPtr &r, const std::shared_ptr<Graph> &g, size_t device_num,
                               size_t device_capacity, size_t hidden_size, size_t vocab_size, size_t seq_length,
                               size_t head_num, size_t layer_num, size_t expansion_ratio, size_t dp, size_t mp,
                               size_t pp, size_t per_batch, size_t micro, bool parallel_opt, bool recompute)
    : root_(r),
      graph_(g),
      num_devices_(device_num),
      device_capacity_(device_capacity),
      vocab_size_(vocab_size),
      seq_length_(seq_length),
      hidden_size_(hidden_size),
      attention_heads_(head_num),
      num_layers_(layer_num),
      expansion_ratio_(expansion_ratio),
      parallel_opt_(parallel_opt),
      recompute_(recompute),
      dp_dim_(dp),
      mp_dim_(mp),
      pp_dim_(pp),
      per_batch_(per_batch),
      num_micros_(micro) {}
497 
SaveConfig()498 void StageComputing::SaveConfig() {
499   saved_config_ = std::make_tuple(dp_dim_, mp_dim_, pp_dim_, per_batch_, num_micros_);
500 }
501 
LoadConfig()502 void StageComputing::LoadConfig() { std::tie(dp_dim_, mp_dim_, pp_dim_, per_batch_, num_micros_) = saved_config_; }
503 
504 // Generalization of num parameters for transformer-based, relying on parsing
// Estimate the parameter count of one pipeline stage holding `l` transformer
// layers; the per-layer component counts come from graph parsing. Returns
// PARSING_FAILED when the representative parameter names were not found.
size_t StageComputing::NumParametersParsing(size_t l) {
  size_t n_weight_MHA;
  size_t n_weight_FF;
  std::tie(n_weight_MHA, n_weight_FF) = GetNumWeightsTransformer(root_);
  size_t n_MHA;
  size_t n_FF;
  size_t n_norm;
  std::tie(n_MHA, n_FF, n_norm) = GetNumTransformerComponents(root_);

  // Each helper fails as a whole tuple, so checking one element per tuple is
  // sufficient to detect failure.
  if (n_weight_MHA == PARSING_FAILED || n_MHA == PARSING_FAILED) {
    return PARSING_FAILED;
  }
  // Attention block: n_weight_MHA square weight matrices plus one bias vector each.
  const size_t P_MHA = n_weight_MHA * (hidden_size_ * hidden_size_ + hidden_size_);
  // Feed-forward block; expansion_ratio_ is the FF inner dimension.
  // NOTE(review): the bias term `+ expansion_ratio_ + hidden_size_` sits outside
  // the n_weight_FF factor — confirm that is the intended per-block bias count.
  const size_t P_FF = n_weight_FF * (expansion_ratio_ * hidden_size_) + expansion_ratio_ + hidden_size_;
  const size_t P_norm = 2 * hidden_size_;  // scale + offset per norm layer
  const size_t P_linear = hidden_size_ * vocab_size_ + vocab_size_;  // LM head weight + bias
  const size_t P_embedding = hidden_size_ * vocab_size_;

  // l layers of (MHA + norm + FF) plus the shared head and embedding tables.
  return l * (n_MHA * P_MHA + n_norm * P_norm + n_FF * P_FF) + P_linear + P_embedding;
}
525 
526 // Generalization of static memory for transformer-based, relying on parsing
// Static memory of one stage: model parameters + accumulated gradients +
// optimizer states. d/t/p are the data-parallel, tensor(model)-parallel and
// pipeline dims; P the stage's parameter count. Returns PARSING_FAILED when
// the FP byte widths could not be parsed from the graph.
size_t StageComputing::GetStaticMemoryParsing(size_t d, size_t t, size_t p, size_t P) {
  size_t FP_params;
  size_t FP_optimizer;
  size_t FP_gradient;
  std::tie(FP_params, FP_optimizer, FP_gradient, std::ignore, std::ignore, std::ignore) = GetFP_formats(root_);
  // GetFP_formats fails as a whole, so checking one element suffices.
  if (FP_params == PARSING_FAILED) {
    return PARSING_FAILED;
  }
  // Parameters are additionally sharded over d only when there is a single stage.
  size_t model_params_size = (FP_params * P) / ((p == 1) ? (d * t) : t);
  // Gradient accumulation shards over d only with one stage AND parallel optimizer.
  size_t accu_gradients_size = (FP_gradient * P) / ((p == 1 && parallel_opt_) ? (d * t) : t);
  // Two optimizer state tensors per parameter (e.g. Adam m and v); sharded over
  // d whenever the parallel optimizer is enabled.
  size_t optim_states_size = (2 * FP_optimizer * P) / ((parallel_opt_) ? (d * t) : t);
  MS_LOG(DEBUG) << "model_params_size: " << TSepBytes(static_cast<size_t>(model_params_size));
  MS_LOG(DEBUG) << "accu_gradients_size: " << TSepBytes(static_cast<size_t>(accu_gradients_size));
  MS_LOG(DEBUG) << "optim_states_size: " << TSepBytes(static_cast<size_t>(optim_states_size));
  return (model_params_size + accu_gradients_size + optim_states_size);
}
543 
544 // Generalization of dynamic memory for transformer-based, relying on parsing
545 // Assuming seq parallelism for dropout and norms
546 // Assuming full recomputation
// Dynamic (activation) memory of one stage holding l layers, batch b, m micro
// batches, p pipeline stages and tensor-parallel t. Assumes sequence parallelism
// for dropout/norm activations and full recomputation between checkpoints.
size_t StageComputing::GetDynamicMemoryParsing(size_t l, size_t b, size_t m, size_t p, size_t t) {
  size_t n_weight_MHA;
  size_t n_weight_FF;
  std::tie(n_weight_MHA, n_weight_FF) = GetNumWeightsTransformer(root_);
  size_t FP_params;
  size_t FP_optimizer;
  size_t FP_gradient;
  size_t FP_norm;
  size_t FP_dropout;
  size_t FP_softmax;
  size_t n_MHA;
  size_t n_FF;
  size_t n_norm;
  std::tie(FP_params, FP_optimizer, FP_gradient, FP_norm, FP_dropout, FP_softmax) = GetFP_formats(root_);
  // NOTE(review): FP_softmax is zeroed immediately after parsing, which removes
  // the softmax activation term from A_MHA below — confirm this override is
  // intentional (e.g. a flash-attention assumption) and not a leftover.
  FP_softmax = 0;
  std::tie(n_MHA, n_FF, n_norm) = GetNumTransformerComponents(root_);
  MS_LOG(DEBUG) << "FP_params: " << FP_params << ", FP_optimizer: " << FP_optimizer << ", FP_gradient: " << FP_gradient;
  MS_LOG(DEBUG) << "FP_dropout: " << FP_dropout << ", FP_softmax: " << FP_softmax << ", FP_norm: " << FP_norm;
  MS_LOG(DEBUG) << "n_MHA: " << n_MHA << ", n_FF: " << n_FF << ", n_norm: " << n_norm;
  MS_LOG(DEBUG) << "n_weight_MHA: " << n_weight_MHA << ", n_weight_FF: " << n_weight_FF;
  // s*b*h: element count of one transformer activation tensor.
  float sbh = seq_length_ * b * hidden_size_;
  // Norm activations, sharded over t (sequence parallelism assumed).
  float A_norm = n_norm * (FP_norm * sbh) / t;
  // Attention block: projections + (s x s) score/dropout masks + output dropout.
  float A_MHA =
    n_MHA * ((n_weight_MHA * FP_params * sbh + FP_softmax * seq_length_ * seq_length_ * b * attention_heads_ +
              FP_dropout * 2 * seq_length_ * seq_length_ * b * attention_heads_) /
               t +
             FP_dropout * sbh / t);
  // Feed-forward block activations (inner dim expansion_ratio_) + dropout.
  float A_FF = n_FF * (static_cast<float>(n_weight_FF * FP_params * seq_length_ * b * expansion_ratio_) / t +
                       FP_dropout * sbh / t);
  float A_intermediate = A_norm + A_MHA + A_FF;
  // Heuristic: 8 devices per node.
  float nodes = static_cast<float>(num_devices_) / 8;
  // Stored layer inputs; scales with micro-batch count on a single node with
  // pipelining, otherwise with the (quartered) node count — heuristic.
  float A_input = ((p > 1 && nodes == 1) ? m : ceil(nodes / 4)) * sbh / t;
  // Heuristic checkpoint count for full recomputation.
  float n_Checkpoints = num_layers_ / 1.5;
  float full_recompute_size = n_Checkpoints * A_input + (l / n_Checkpoints) * A_intermediate;

  // Tensor-parallel communication buffers per layer.
  float communications_size = static_cast<float>(8 * seq_length_ * b * hidden_size_ * (t - 1)) / t;
  MS_LOG(DEBUG) << "l: " << l;
  MS_LOG(DEBUG) << "n_Checkpoints: " << n_Checkpoints;
  MS_LOG(DEBUG) << "A_input: " << TSepBytes(static_cast<size_t>(A_input));
  MS_LOG(DEBUG) << "A_norm: " << TSepBytes(static_cast<size_t>(A_norm));
  MS_LOG(DEBUG) << "A_MHA: " << TSepBytes(static_cast<size_t>(A_MHA));
  MS_LOG(DEBUG) << "A_FF: " << TSepBytes(static_cast<size_t>(A_FF));
  MS_LOG(DEBUG) << "A_intermediate: " << TSepBytes(static_cast<size_t>(A_intermediate));
  MS_LOG(DEBUG) << "l*communication: " << TSepBytes(static_cast<size_t>(l * communications_size));
  MS_LOG(DEBUG) << "l/n_Checkpoints * A_intermediate: "
                << TSepBytes(static_cast<size_t>((l / n_Checkpoints) * A_intermediate));
  MS_LOG(DEBUG) << "n_Checkpoints * A_input: " << TSepBytes(static_cast<size_t>(n_Checkpoints * A_input));

  return static_cast<size_t>(full_recompute_size + l * communications_size);
}
597 
598 // Manually compute global batch size
GlobalBatchSize()599 size_t StageComputing::GlobalBatchSize() { return per_batch_ * num_micros_; }
600 
601 // Get layer per stage
GetLayerPerStage()602 size_t StageComputing::GetLayerPerStage() {
603   return ceil(static_cast<float>(num_layers_) / static_cast<float>(pp_dim_));
604 }
605 
GetMemory()606 size_t StageComputing::GetMemory() {
607   size_t P3 = NumParametersParsing(GetLayerPerStage());
608   if (P3 == PARSING_FAILED) {
609     return PARSING_FAILED;
610   }
611   size_t sMem3 = GetStaticMemoryParsing(dp_dim_, mp_dim_, pp_dim_, P3);
612   if (sMem3 == PARSING_FAILED) {
613     return PARSING_FAILED;
614   }
615   size_t dMem3 = GetDynamicMemoryParsing(GetLayerPerStage(), per_batch_, num_micros_, pp_dim_, mp_dim_);
616   PrintResults(sMem3, dMem3, P3);
617   return (sMem3 + dMem3);
618 }
619 
620 // MS_LOG
// Log every parsed hyperparameter at INFO level. Shorthand: h = hidden size,
// s = sequence length, v = vocab size, a = attention heads, L = layers,
// pb = per-batch, B = global batch, er = expansion ratio, opt = parallel
// optimizer enabled, rcpt = recompute enabled.
void StageComputing::PrintHyperparams() {
  MS_LOG(INFO) << "Hyperparameters : h : " << hidden_size_ << ", s : " << seq_length_ << ", v : " << vocab_size_
               << ", a : " << attention_heads_ << ", L : " << num_layers_ << ", pb:" << per_batch_
               << ", B :" << GlobalBatchSize() << ", er: " << expansion_ratio_ << ", opt: " << parallel_opt_
               << ", rcpt: " << recompute_;
}
627 
Suggestion(const std::string & suggestion)628 void Suggestion(const std::string &suggestion) {
629   MS_LOG(INFO) << std::endl
630                << "=================== Auto Parallel Config Suggestion by SAPP ===================" << std::endl
631                << suggestion << std::endl
632                << "===============================================================================";
633 }
634 
// Render one suggested configuration (mem_coeff, pipeline stage, batch size,
// micro-batch count) as the two-line text used inside the SAPP banner.
std::string ParamSuggest(float mem_coeff, size_t stage, size_t batch, size_t micro) {
  std::ostringstream out;
  out << " mem_coeff: " << mem_coeff << ", pipeline_stage: " << stage << '\n'
      << " batch_size: " << batch << ", micro_batch_num: " << micro;
  return out.str();
}
645 
FittingSuggestion()646 void StageComputing::FittingSuggestion() {
647   std::stringstream ss;
648 
649   ss << " SAPP algorithm suggests the following parallel configuration:" << std::endl;
650   ss << ParamSuggest(CostModelContext::GetInstance()->rp_matmul_mem_coef(), pp_dim_, per_batch_, num_micros_);
651 
652   Suggestion(ss.str());
653 }
654 
OOMSuggestion()655 void StageComputing::OOMSuggestion() {
656   std::stringstream ss;
657 
658   float default_coeff = 1024;
659   ss << " The current configuration seem to not fit in memory." << std::endl;
660   ss << " SAPP algorithm suggests to change configuration to:" << std::endl;
661   ss << ParamSuggest(default_coeff, pp_dim_, 1, num_micros_ * per_batch_);
662 
663   Suggestion(ss.str());
664 }
665 
// Emitted when any graph-parsing probe returned PARSING_FAILED: the estimator
// cannot run, so the previously chosen stage number is kept unchanged.
void StageComputing::ParsingException() {
  MS_LOG(WARNING) << "Something went wrong during the graph parsing process.";
  MS_LOG(WARNING) << "SAPP algorithm uses original stage number";
}
670 
// Log the evaluated configuration together with its static/dynamic/total
// memory estimate (INFO level); memory values are thousand-separated MiB.
void StageComputing::PrintResults(size_t StaticMEM, size_t DynamicMEM, size_t num_param) {
  MS_LOG(INFO) << "DP: " << dp_dim_ << ", MP: " << mp_dim_ << ", Stages: " << pp_dim_ << ", n_micros: " << num_micros_
               << ", Per_batch: " << per_batch_ << ", Global batch size: " << GlobalBatchSize()
               << ", Num Parameters: " << num_param;
  MS_LOG(INFO) << "StaticMEM: " << TSepBytes(StaticMEM) << ", DynamicMEM: " << TSepBytes(DynamicMEM)
               << ", totalMEM: " << TSepBytes(StaticMEM + DynamicMEM);
}
678 
CurrentEstimation()679 size_t StageComputing::CurrentEstimation() { return GetMemory(); }
680 
fits(size_t memory)681 bool StageComputing::fits(size_t memory) { return ((MARGIN_FACTOR * memory) < device_capacity_); }
682 
683 // Suggest a parallel config (dp,mp,pp) + batch (micro, per batch, dp)
684 // maintain global batch size
685 // memory function as cost model
FindSmallerStage()686 Status StageComputing::FindSmallerStage() {
687   if (pp_dim_ == 1) {
688     return FAILED;
689   }
690   SaveConfig();
691 
692   bool saved = false;
693   size_t factor = 2;
694   double ratio = static_cast<float>(dp_dim_) / static_cast<float>(mp_dim_);
695   for (; pp_dim_ >= 1; pp_dim_ /= factor) {
696     if (fits(GetMemory()) && !saved) {
697       SaveConfig();
698       saved = true;
699       MS_LOG(INFO) << "Stage " << pp_dim_ << " is selected";
700     }
701     float ratio_factor_dp = abs(static_cast<float>(factor * dp_dim_) / mp_dim_ - ratio);
702     float ratio_factor_mp = abs(static_cast<float>(dp_dim_) / (factor * mp_dim_) - ratio);
703     if (mp_dim_ == 1)
704       dp_dim_ *= factor;
705     else if (dp_dim_ == 1)
706       mp_dim_ *= factor;
707     else if (ratio_factor_dp <= ratio_factor_mp)
708       dp_dim_ *= factor;
709     else
710       mp_dim_ *= factor;
711   }
712 
713   LoadConfig();
714 
715   if (!saved) {
716     return FAILED;
717   }
718   return SUCCESS;
719 }
720 
721 // Estimation compute
LaunchStageCompute()722 size_t StageComputing::LaunchStageCompute() {
723   size_t pp = pp_dim_;
724   if (vocab_size_ == PARSING_FAILED || seq_length_ == PARSING_FAILED || expansion_ratio_ == PARSING_FAILED ||
725       num_layers_ == PARSING_FAILED || dp_dim_ == PARSING_FAILED || num_micros_ == PARSING_FAILED ||
726       per_batch_ == PARSING_FAILED) {
727     ParsingException();
728     return pp;
729   }
730 
731   MS_LOG(INFO) << "Current stage number : " << pp;
732   MS_LOG(DEBUG) << "Number of devices: " << num_devices_;
733 
734   PrintHyperparams();
735   if (CurrentEstimation() == PARSING_FAILED) {
736     ParsingException();
737     return pp;
738   }
739   if (FindSmallerStage() == FAILED) {
740     OOMSuggestion();
741     return pp;
742   } else {
743     FittingSuggestion();
744   }
745 
746   return pp_dim_;
747 }
748 
749 // Suggest a pipeline stage
ParallelSuggestion(const FuncGraphPtr & root,const std::shared_ptr<Graph> & graph)750 size_t ParallelSuggestion(const FuncGraphPtr &root, const std::shared_ptr<Graph> &graph) {
751   size_t vocab;
752   size_t seq;
753   size_t heads;
754   size_t dp;
755   size_t mp;
756   size_t pp;
757   size_t hidden;
758   size_t layers;
759   size_t devices;
760   size_t capacity;
761   size_t micros;
762   size_t per_batch;
763   size_t er;
764   bool opt;
765   bool recompute;
766 
767   pp = static_cast<size_t>(parallel::ParallelContext::GetInstance()->pipeline_stage_split_num());
768   if (root == nullptr || graph == nullptr) {
769     MS_LOG(WARNING) << "Null costgraph or recgraph, ParallelSuggestion cannot run.";
770     MS_LOG(WARNING) << "SAPP algorithm uses original stage number.";
771     return pp;
772   }
773 
774   std::tie(seq, heads) = GetSeqLengthAndAttentionHeads(root);
775   std::tie(dp, mp) = GetDPAndMP(graph, pp);
776   std::tie(hidden, vocab) = GetVocabAndHiddenSize(root);
777   er = GetExpansionRatio(root);
778   layers = GetNumLayers(root);
779   capacity = GetDeviceCapacity();
780   micros = GetNumMicro(root);
781 
782   per_batch = GetPerBatch(root, seq);
783   devices = GetNumDevices();
784   opt = HasParallelOptimizer(root);
785   recompute = HasRecompute(root);
786 
787   StageComputing sc(root, graph, devices, capacity, hidden, vocab, seq, heads, layers, er, dp, mp, pp, per_batch,
788                     micros, opt, recompute);
789   pp = sc.LaunchStageCompute();
790   return pp;
791 }
792 
IsGraphFilter(const AnfNodePtr & node)793 bool IsGraphFilter(const AnfNodePtr &node) { return !IsValueNode<FuncGraph>(node); }
794 
795 // Update old stage number with suggestion
ChangeStageNumber(const FuncGraphPtr & root,size_t new_stage_num)796 void ChangeStageNumber(const FuncGraphPtr &root, size_t new_stage_num) {
797   size_t old_stage = static_cast<size_t>(parallel::ParallelContext::GetInstance()->pipeline_stage_split_num());
798   if (old_stage == new_stage_num) {
799     MS_LOG(INFO) << "Stage number " << new_stage_num << " is the same as the old value. Nothing changed.";
800     return;
801   }
802 
803   if (old_stage % new_stage_num != 0) {
804     MS_LOG(WARNING) << "Stage number " << new_stage_num << " is not a divisor of the previous stage number "
805                     << old_stage << ". Stage Number is NOT changed.";
806     return;
807   }
808 
809   size_t change_factor = old_stage / new_stage_num;
810   MS_LOG(DEBUG) << "Old stage number:" << old_stage << " ; Change factor:" << change_factor;
811 
812   FuncGraphPtr main_graph;
813   // Get main graph
814   auto manager = root->manager();
815   if (!root->has_flag(kTraining)) {
816     main_graph = root;
817   } else {
818     for (auto &fg : manager->func_graphs()) {
819       for (auto &node : fg->nodes()) {
820         if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
821           main_graph = fg;
822           break;
823         }
824       }
825     }
826   }
827 
828   // Get all sub graphs
829   auto nodes = DeepScopedGraphSearchWithFilter(main_graph->get_return(), AlwaysInclude, IsGraphFilter);
830   std::reverse(nodes.begin(), nodes.end());
831   std::vector<FuncGraphPtr> subgraphs;
832   for (auto &node : nodes) {
833     auto graph = GetValueNode<FuncGraphPtr>(node);
834     subgraphs.push_back(graph);
835   }
836 
837   // Update stage in all sub_graphs
838   for (auto &graph : subgraphs) {
839     int graph_old_stage = graph->stage();
840     if (graph_old_stage != -1) {
841       graph->set_stage(graph_old_stage / change_factor);  // Either increase or decrease
842     }
843   }
844 
845   // Update stage in parallel context
846   parallel::ParallelContext::GetInstance()->set_pipeline_stage_split_num(old_stage / change_factor);
847   MS_LOG(INFO) << "END ChangeStageNumber"
848                << ", new stage number: " << (old_stage / change_factor);
849 }
850 
851 }  // namespace parallel
852 }  // namespace mindspore
853