/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/stage_compute.h"

#include <algorithm>
#include <utility>
#include <string>
#include <vector>
#include <map>
#include <regex>

#include "mindspore/core/ops/array_ops.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/parallel_node_check.h"
#include "ir/func_graph.h"
#include "include/common/utils/parallel_context.h"
#include "mindspore/core/utils/ms_utils.h"

namespace mindspore {
namespace parallel {

constexpr size_t PARSING_FAILED = SIZE_MAX;
constexpr size_t Kilo = 1024;
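// Safety margin: an estimated footprint only "fits" if 110% of it is below the device capacity
// (see StageComputing::fits).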
constexpr double MARGIN_FACTOR = 1.1;

// Format a byte count as mebibytes with thousand separators.
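// e.g. TSepBytes(3456789012) returns "3,296M" (floor division from bytes to MiB).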
string TSepBytes(size_t n) {
  string res = std::to_string(n / Kilo / Kilo);
  int thousand_digit_num = 3;
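  // Walk right to left, inserting a separator before every group of three digits.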
  for (int i = static_cast<int>(res.size()) - thousand_digit_num; i > 0; i -= thousand_digit_num) res.insert(i, ",");
  return res + "M";
}

std::vector<AnfNodePtr> GetNodes(const FuncGraphPtr &root) {
  AnfNodePtr ret_forward = root->get_return();
  return DeepScopedGraphSearch(ret_forward);
}

// Get the number of layers (each model has unique layer names to analyse)
size_t GetNumLayers(const FuncGraphPtr &root) {
  const std::string kHeadLayer = "Head";
  const std::string kNormLayer = "Norm";
  const std::string kEmbeddingLayer = "Embedding";
  const std::string kLinearLayer = "Linear";
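  // Cells whose names contain one of the substrings above are pre/post-processing cells,
  // not repeated transformer layers.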
  std::vector<FuncGraphPtr> pipeline_cells;
  size_t num_layers = 0;

  auto forward_nodes = GetNodes(root);
  for (auto node : forward_nodes) {
    if (!node->isa<CNode>()) continue;

    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<FuncGraph>(cnode->input(0))) {
      continue;
    }

    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    if (graph->stage() == -1 ||
        std::find(pipeline_cells.begin(), pipeline_cells.end(), graph) != pipeline_cells.end()) {
      continue;
    }

    pipeline_cells.push_back(graph);
    std::string name = graph->ToString();
    // Remove pre/post cells
    if (!(name.find(kHeadLayer) != std::string::npos || name.find(kNormLayer) != std::string::npos ||
          name.find(kEmbeddingLayer) != std::string::npos || name.find(kLinearLayer) != std::string::npos)) {
      MS_LOG(DEBUG) << name << " is counted as a layer";
      num_layers++;
    } else {
      MS_LOG(DEBUG) << name << " is NOT a normal layer";
    }
  }

  return (num_layers > 0) ? num_layers : PARSING_FAILED;
}

size_t GetNumDevices() { return g_device_manager->DeviceNum(); }

// Get parallel_optimizer
bool HasParallelOptimizer(const FuncGraphPtr &root) {
  return parallel::ParallelContext::GetInstance()->enable_parallel_optimizer();
}

// Check if recomputation was chosen. Currently only able to check select_recompute
bool HasRecompute(const FuncGraphPtr &root) {
  auto forward_nodes = GetNodes(root);
  for (auto &forward_node : forward_nodes) {
    if (!forward_node->isa<CNode>()) {
      continue;
    }
    auto cnode = forward_node->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(cnode->input(0))) {
      auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
      if (fg->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
        MS_LOG(DEBUG) << "found recompute cell " << fg->ToString();
        return true;
      }
    }

    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
      auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      if (current_prim != nullptr) {
        auto prim_recompute_attr = current_prim->GetAttr(kAttrRecompute);
        if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
          auto recomputed = GetValue<bool>(prim_recompute_attr);
          if (recomputed) {
            MS_LOG(DEBUG) << "found recompute node " << current_prim->name();
            return true;
          }
        }
      }
    }
  }

  return false;
}

// Get DP and MP dimensions
std::tuple<size_t, size_t> GetDPAndMP(const std::shared_ptr<Graph> &graph, const size_t stage) {
  std::map<std::string, int> strategy_occurrence;

  size_t dp = 0;
  size_t mp = 0;
  const unsigned int kTargetLength = 2;
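  // The last log2(stage) entries of a node's strategy history come from pipeline cuts,
  // so roll back past them to read the (dp, mp) strategy.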
  size_t roll_back = FloatToSize(log2(stage));
  for (auto &node_ptr : graph->nodes) {
    if (node_ptr.apply.op_type == kRecMatMul) {
      size_t n_cut = node_ptr.apply.strs.size() - roll_back - 1;
      if (n_cut >= node_ptr.apply.strs.size()) {
        MS_LOG(WARNING) << "Strategy of " << node_ptr.name << " not available";
        return {PARSING_FAILED, PARSING_FAILED};
      }
      StrategyRec strategy = node_ptr.apply.strs[n_cut];
      if (sizeof(strategy.inputTensor) / sizeof(TensorStr4D) >= kTargetLength) {
        MS_LOG(DEBUG) << "inputTensor[0] " << strategy.inputTensor[0].str_w << " " << strategy.inputTensor[0].str_h
                      << " " << strategy.inputTensor[0].str_c << " " << strategy.inputTensor[0].str_n;
        MS_LOG(DEBUG) << "inputTensor[1] " << strategy.inputTensor[1].str_w << " " << strategy.inputTensor[1].str_h
                      << " " << strategy.inputTensor[1].str_c << " " << strategy.inputTensor[1].str_n;
        int mp_strat = 1;
        int dp_strat = 1;
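        // str_* hold the fraction of each tensor dimension kept per device (1 / number_of_cuts),
        // so inverting recovers the parallel degree.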
        if (strategy.inputTensor[1].str_h * strategy.inputTensor[1].str_w != 0) {
          mp_strat = static_cast<int>(1 / (strategy.inputTensor[1].str_h * strategy.inputTensor[1].str_w));
        }
        if (strategy.inputTensor[0].str_h != 0) {
          dp_strat = static_cast<int>(1 / strategy.inputTensor[0].str_h);
        }
        MS_LOG(DEBUG) << "dp_strat: " << dp_strat << ", mp_strat: " << mp_strat;
        std::string strategy_str = std::to_string(dp_strat) + "," + std::to_string(mp_strat);
        auto it = strategy_occurrence.find(strategy_str);
        if (it == strategy_occurrence.end()) {
          strategy_occurrence.insert(std::pair<std::string, int>(strategy_str, 1));
        } else {
          it->second++;
        }
      } else {
        MS_LOG(DEBUG) << "MatMul strategy found but null";
      }
    }
  }
  // Take the (DP,MP) that appears the most
  int occurrence = 0;
  for (auto it = strategy_occurrence.begin(); it != strategy_occurrence.end(); it++) {
    if (it->second > occurrence) {
      auto stra = it->first;
      auto pos = stra.find(",");
      dp = static_cast<size_t>(std::stoi(stra.substr(0, pos)));
      mp = static_cast<size_t>(std::stoi(stra.substr(pos + 1, stra.size() - pos)));
      occurrence = it->second;
    }
  }
  if (dp > 0 && mp > 0) {
    return {dp, mp};
  }
  return {PARSING_FAILED, PARSING_FAILED};
}

// Get Vocab Size and Hidden Size as a tuple
std::tuple<size_t, size_t> GetVocabAndHiddenSize(const FuncGraphPtr &root) {
  size_t hidden_size = 0;
  size_t vocab_size = 0;
  std::vector<AnfNodePtr> parameters = root->parameters();
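  // Names are matched against common transformer parameter naming (e.g. "...0.attention...weight",
  // "...tok_embeddings..."): hidden size is read from an attention weight shape, vocab size from the
  // embedding table shape.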
  for (auto &p : parameters) {
    auto parameter_ptr = p->cast<ParameterPtr>();
    Shapes param_shapes = GetNodeShape(p);
    if (hidden_size == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.attention.*.weight"))) {
      hidden_size = static_cast<size_t>(param_shapes[0][1]);
      MS_LOG(DEBUG) << "Parameter for hidden size: " << parameter_ptr->name().c_str() << "; with shape " << param_shapes
                    << "; h = " << hidden_size;
    } else if (vocab_size == 0 && (std::regex_match(parameter_ptr->name().c_str(), std::regex(".*word_embedding.*")) ||
                                   std::regex_match(parameter_ptr->name().c_str(), std::regex(".*tok_embeddings.*")))) {
      vocab_size = static_cast<size_t>(param_shapes[0][0]);
      MS_LOG(DEBUG) << "Parameter for vocab size: " << parameter_ptr->name().c_str() << "; with shape " << param_shapes
                    << "; v = " << vocab_size;
    } else {
      MS_LOG(DEBUG) << "Parameter " << parameter_ptr->name().c_str() << "; with shape " << param_shapes;
    }
    if (hidden_size > 0 && vocab_size > 0) break;
  }
  if (hidden_size > 0 && vocab_size > 0) {
    return {hidden_size, vocab_size};
  }
  return {PARSING_FAILED, PARSING_FAILED};
}

// Get Sequence Length and Attention Heads as a tuple
std::tuple<size_t, size_t> GetSeqLengthAndAttentionHeads(const FuncGraphPtr &root) {
  size_t seq_length = 0;
  size_t attention_heads = 0;
  auto forward_nodes = GetNodes(root);
  const size_t kTargetShape = 4;
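  // GetNext yields [batch, seq + 1] token ids (one extra token for the label shift), hence the -1 below;
  // Softmax/FlashAttentionScore outputs are 4-D [b, a, s, s], so dim 1 gives the head count.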
  for (auto &forward_node : forward_nodes) {
    if (forward_node->isa<CNode>()) {
      auto cnode = forward_node->cast<CNodePtr>();
      if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
        auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
        if (seq_length == 0 && strcmp(current_prim->name().c_str(), GET_NEXT) == 0) {
          Shapes param_shapes = GetNodeShape(cnode);
          MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
          seq_length = static_cast<size_t>(param_shapes[0][1] - 1);
        }
        if (attention_heads == 0 && (strcmp(current_prim->name().c_str(), SOFTMAX) == 0 ||
                                     strcmp(current_prim->name().c_str(), FLASH_ATTENTION_SCORE) == 0)) {
          Shapes param_shapes = GetNodeShape(cnode);
          if (param_shapes[0].size() == kTargetShape) {
            MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
            attention_heads = static_cast<size_t>(param_shapes[0][1]);
          }
        }
        if (attention_heads > 0 && seq_length > 0) {
          break;
        }
      }
    }
  }
  if (seq_length > 0 && attention_heads > 0) {
    return {seq_length, attention_heads};
  }
  return {PARSING_FAILED, PARSING_FAILED};
}

// Get num micro
size_t GetNumMicro(const FuncGraphPtr &root) {
  auto manager = root->manager();
  AnfNodePtr virtual_dataset;
  for (auto &fg : manager->func_graphs()) {
    for (auto &node : fg->nodes()) {
      if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
        virtual_dataset = node;
        break;
      }
    }
  }
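  // Each micro batch slices the dataset output once, so the number of users (StridedSlice nodes)
  // of the TupleGetItem output gives the micro batch count.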
  auto node_user_map = manager->node_users();
  auto node_users = node_user_map[virtual_dataset];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto data_users = manager->node_users()[node_user.first];
      auto node_first = data_users.front().first;
      if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
        data_users = node_user_map[node_first];
      }
      MS_LOG(DEBUG) << "micro batch size found: " << int64_t(data_users.size());
      return int64_t(data_users.size());
    }
  }
  MS_LOG(WARNING) << "micro batch size not found";
  return PARSING_FAILED;
}

// Get the per-device batch size
size_t GetPerBatch(const FuncGraphPtr &root, size_t seq_l) {
  size_t per_batch = 0;
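  // The first MatMul activation has flattened shape [b * s, h], so the per-device batch is
  // dim 0 divided by seq_length.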
  auto forward_nodes = GetNodes(root);
  for (auto &forward_node : forward_nodes) {
    if (forward_node->isa<CNode>()) {
      auto cnode = forward_node->cast<CNodePtr>();
      if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
        auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
        if (per_batch == 0 && strcmp(current_prim->name().c_str(), MATMUL) == 0) {
          Shapes param_shapes = GetNodeShape(cnode);
          MS_LOG(DEBUG) << current_prim->name().c_str() << " with shape " << param_shapes;
          per_batch = static_cast<size_t>(param_shapes[0][0]) / seq_l;
          MS_LOG(DEBUG) << "batch size found: " << per_batch;
        }
        if (per_batch > 0) {
          break;
        }
      }
    }
  }
  return (per_batch > 0) ? per_batch : PARSING_FAILED;
}

std::tuple<size_t, size_t, size_t, size_t> GetFPFromParams(const FuncGraphPtr &root) {
  size_t fp_params = 0;
  size_t fp_optim = 0;
  size_t fp_grads = 0;
  size_t fp_norm = 0;

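  // Byte widths are read from representative parameters: accumulated gradients ("accu_grads.*"),
  // Adam first moments ("adam_m.*"), embedding weights, and norm weights.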
  std::vector<AnfNodePtr> parameters = root->parameters();
  for (auto &p : parameters) {
    if (p == nullptr) {
      continue;
    }
    auto parameter_ptr = p->cast<ParameterPtr>();
    mindspore::TypePtr element_type;
    auto data_type = p->Type();
    MS_EXCEPTION_IF_NULL(data_type);
    if (!data_type->isa<mindspore::TensorType>()) {
      continue;
    }
    element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
    MS_EXCEPTION_IF_NULL(element_type);
    auto type_id = element_type->type_id();
    if (fp_grads == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex("accu_grads.*embedding.*"))) {
      fp_grads = GetTypeByte(TypeIdToType(type_id));
    } else if (fp_optim == 0 && std::regex_match(parameter_ptr->name().c_str(), std::regex("adam_m.*embedding.*"))) {
      fp_optim = GetTypeByte(TypeIdToType(type_id));
    } else if (fp_params == 0 &&
               (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.*embedding.*")) ||
                std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.*embedding.*")))) {
      fp_params = GetTypeByte(TypeIdToType(type_id));
    } else if (fp_norm == 0 &&
               (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.*attention_norm.*")) ||
                std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.*layernorm.*")))) {
      fp_norm = GetTypeByte(TypeIdToType(type_id));
    }
    if (fp_optim > 0 && fp_params > 0 && fp_grads > 0 && fp_norm > 0) {
      return {fp_params, fp_optim, fp_grads, fp_norm};
    }
  }

  return {fp_params, fp_optim, fp_grads, fp_norm};
}

std::tuple<size_t, size_t> GetFPFromNodes(const FuncGraphPtr &root) {
  size_t fp_dropout = 0;
  size_t fp_softmax = 0;

  auto forward_nodes = GetNodes(root);
  for (auto &forward_node : forward_nodes) {
    if (forward_node->isa<CNode>()) {
      auto cnode = forward_node->cast<CNodePtr>();
      if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
        auto current_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
        mindspore::TypePtr element_type;
        auto data_type = cnode->Type();
        MS_EXCEPTION_IF_NULL(data_type);
        if (!data_type->isa<mindspore::TensorType>()) {
          if (std::regex_match(current_prim->name().c_str(), std::regex("Dropout"))) {
            fp_dropout = 1;
          }
          continue;
        }
        element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
        MS_EXCEPTION_IF_NULL(element_type);
        auto type_id = element_type->type_id();
        if (fp_softmax == 0 && (std::regex_match(current_prim->name().c_str(), std::regex(SOFTMAX)))) {
          fp_softmax = GetTypeByte(TypeIdToType(type_id));
        } else if (fp_dropout == 0 && std::regex_match(current_prim->name().c_str(), std::regex(DROPOUT))) {
          fp_dropout = GetTypeByte(TypeIdToType(type_id));
        }
        if (fp_softmax > 0 && fp_dropout > 0) {
          return {fp_dropout, fp_softmax};
        }
      }
    }
  }

  return {fp_dropout, fp_softmax};
}

// Get FP byte widths for params, optimizer, gradients, norm, dropout, softmax
std::tuple<size_t, size_t, size_t, size_t, size_t, size_t> GetFP_formats(const FuncGraphPtr &root) {
  size_t fp_params = 0;
  size_t fp_optim = 0;
  size_t fp_grads = 0;
  size_t fp_norm = 0;
  size_t fp_dropout = 0;
  size_t fp_softmax = 0;

  std::tie(fp_params, fp_optim, fp_grads, fp_norm) = GetFPFromParams(root);
  std::tie(fp_dropout, fp_softmax) = GetFPFromNodes(root);

  if (fp_params == 0 || fp_optim == 0 || fp_grads == 0 || fp_norm == 0) {
    return {PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED, PARSING_FAILED};
  }

  return {fp_params, fp_optim, fp_grads, fp_norm, fp_dropout, fp_softmax};
}

size_t GetExpansionRatio(const FuncGraphPtr &root) {
  std::vector<AnfNodePtr> parameters = root->parameters();
  for (auto &p : parameters) {
    auto parameter_ptr = p->cast<ParameterPtr>();
    Shapes param_shapes = GetNodeShape(p);
    if (std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.feed_forward.*.weight")) ||
        std::regex_match(parameter_ptr->name().c_str(), std::regex(".*0.output.projection.weight"))) {
      return param_shapes[0][0];
    }
  }
  return PARSING_FAILED;
}

// Count multi-head attention blocks, feed-forward blocks, and norms per layer
std::tuple<size_t, size_t, size_t> GetNumTransformerComponents(const FuncGraphPtr &root) {
  size_t n_mha = 0;
  size_t n_ff = 0;
  size_t n_norm = 0;

  std::vector<AnfNodePtr> parameters = root->parameters();
  for (auto &p : parameters) {
    if (p == nullptr) {
      continue;
    }
    auto parameter_ptr = p->cast<ParameterPtr>();
    if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention.wo.weight")) ||
        std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.attention.dense1.weight"))) {
      n_mha++;
    } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.feed_forward.w1.weight")) ||
               std::regex_match(parameter_ptr->name().c_str(),
                                std::regex("^backbone.blocks.0.output.mapping.weight"))) {
      n_ff++;
    } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention_norm.weight")) ||
               std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.ffn_norm.weight")) ||
               std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.layernorm.*gamma"))) {
      n_norm++;
    }
  }
  if (n_mha > 0 && n_ff > 0 && n_norm > 0) {
    return {n_mha, n_ff, n_norm};
  }
  return {PARSING_FAILED, PARSING_FAILED, PARSING_FAILED};
}

// Count weight matrices in the MHA and FF blocks
std::tuple<size_t, size_t> GetNumWeightsTransformer(const FuncGraphPtr &root) {
  size_t n_weight_MHA = 0;
  size_t n_weight_FF = 0;

  std::vector<AnfNodePtr> parameters = root->parameters();
  for (auto &p : parameters) {
    if (p == nullptr) {
      continue;
    }
    auto parameter_ptr = p->cast<ParameterPtr>();
    if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.attention.w.*weight")) ||
        std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.attention.*weight"))) {
      n_weight_MHA++;
    } else if (std::regex_match(parameter_ptr->name().c_str(), std::regex("^model.layers.0.feed_forward.*.weight")) ||
               std::regex_match(parameter_ptr->name().c_str(), std::regex("^backbone.blocks.0.output.*weight"))) {
      n_weight_FF++;
    }
  }
  if (n_weight_MHA > 0 && n_weight_FF > 0) {
    return {n_weight_MHA, n_weight_FF};
  }
  return {PARSING_FAILED, PARSING_FAILED};
}

StageComputing::StageComputing(const FuncGraphPtr &r, const std::shared_ptr<Graph> &g, size_t device_num,
                               size_t device_capacity, size_t hidden_size, size_t vocab_size, size_t seq_length,
                               size_t head_num, size_t layer_num, size_t expansion_ratio, size_t dp, size_t mp,
                               size_t pp, size_t per_batch, size_t micro, bool parallel_opt, bool recompute)
    : root_(r),
      graph_(g),
      num_devices_(device_num),
      device_capacity_(device_capacity),
      vocab_size_(vocab_size),
      seq_length_(seq_length),
      hidden_size_(hidden_size),
      attention_heads_(head_num),
      num_layers_(layer_num),
      expansion_ratio_(expansion_ratio),
      parallel_opt_(parallel_opt),
      recompute_(recompute),
      dp_dim_(dp),
      mp_dim_(mp),
      pp_dim_(pp),
      per_batch_(per_batch),
      num_micros_(micro) {}

void StageComputing::SaveConfig() {
  saved_config_ = std::make_tuple(dp_dim_, mp_dim_, pp_dim_, per_batch_, num_micros_);
}

void StageComputing::LoadConfig() { std::tie(dp_dim_, mp_dim_, pp_dim_, per_batch_, num_micros_) = saved_config_; }

// Generalization of num parameters for transformer-based, relying on parsing
size_t StageComputing::NumParametersParsing(size_t l) {
  size_t n_weight_MHA;
  size_t n_weight_FF;
  std::tie(n_weight_MHA, n_weight_FF) = GetNumWeightsTransformer(root_);
  size_t n_MHA;
  size_t n_FF;
  size_t n_norm;
  std::tie(n_MHA, n_FF, n_norm) = GetNumTransformerComponents(root_);

  if (n_weight_MHA == PARSING_FAILED || n_MHA == PARSING_FAILED) {
    return PARSING_FAILED;
  }
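  // Per-layer counts: each MHA weight matrix is h*h plus an h-sized bias; each FF weight is er*h plus
  // biases (er is the absolute FF inner dimension, despite the name); each norm holds two h-sized
  // vectors; the output head and the embedding table are both h*v.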
  const size_t P_MHA = n_weight_MHA * (hidden_size_ * hidden_size_ + hidden_size_);
  const size_t P_FF = n_weight_FF * (expansion_ratio_ * hidden_size_) + expansion_ratio_ + hidden_size_;
  const size_t P_norm = 2 * hidden_size_;
  const size_t P_linear = hidden_size_ * vocab_size_ + vocab_size_;
  const size_t P_embedding = hidden_size_ * vocab_size_;

  return l * (n_MHA * P_MHA + n_norm * P_norm + n_FF * P_FF) + P_linear + P_embedding;
}

// Generalization of static memory for transformer-based, relying on parsing
size_t StageComputing::GetStaticMemoryParsing(size_t d, size_t t, size_t p, size_t P) {
  size_t FP_params;
  size_t FP_optimizer;
  size_t FP_gradient;
  std::tie(FP_params, FP_optimizer, FP_gradient, std::ignore, std::ignore, std::ignore) = GetFP_formats(root_);
  if (FP_params == PARSING_FAILED) {
    return PARSING_FAILED;
  }
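  // Sharding, as implemented below: parameters are split across the t model-parallel ranks, and also
  // across the d data-parallel ranks when there is a single pipeline stage; accumulated gradients
  // additionally require the parallel optimizer; optimizer states (two Adam moments, hence the
  // factor 2) are split by d whenever the parallel optimizer is enabled.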
  size_t model_params_size = (FP_params * P) / ((p == 1) ? (d * t) : t);
  size_t accu_gradients_size = (FP_gradient * P) / ((p == 1 && parallel_opt_) ? (d * t) : t);
  size_t optim_states_size = (2 * FP_optimizer * P) / ((parallel_opt_) ? (d * t) : t);
  MS_LOG(DEBUG) << "model_params_size: " << TSepBytes(model_params_size);
  MS_LOG(DEBUG) << "accu_gradients_size: " << TSepBytes(accu_gradients_size);
  MS_LOG(DEBUG) << "optim_states_size: " << TSepBytes(optim_states_size);
  return (model_params_size + accu_gradients_size + optim_states_size);
}

// Generalization of dynamic memory for transformer-based, relying on parsing
// Assuming seq parallelism for dropout and norms
// Assuming full recomputation
size_t StageComputing::GetDynamicMemoryParsing(size_t l, size_t b, size_t m, size_t p, size_t t) {
  size_t n_weight_MHA;
  size_t n_weight_FF;
  std::tie(n_weight_MHA, n_weight_FF) = GetNumWeightsTransformer(root_);
  size_t FP_params;
  size_t FP_optimizer;
  size_t FP_gradient;
  size_t FP_norm;
  size_t FP_dropout;
  size_t FP_softmax;
  size_t n_MHA;
  size_t n_FF;
  size_t n_norm;
  std::tie(FP_params, FP_optimizer, FP_gradient, FP_norm, FP_dropout, FP_softmax) = GetFP_formats(root_);
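  // FP_softmax is zeroed below: softmax activations are assumed not to be stored
  // (e.g. recomputed or fused by FlashAttention).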
  FP_softmax = 0;
  std::tie(n_MHA, n_FF, n_norm) = GetNumTransformerComponents(root_);
  MS_LOG(DEBUG) << "FP_params: " << FP_params << ", FP_optimizer: " << FP_optimizer << ", FP_gradient: " << FP_gradient;
  MS_LOG(DEBUG) << "FP_dropout: " << FP_dropout << ", FP_softmax: " << FP_softmax << ", FP_norm: " << FP_norm;
  MS_LOG(DEBUG) << "n_MHA: " << n_MHA << ", n_FF: " << n_FF << ", n_norm: " << n_norm;
  MS_LOG(DEBUG) << "n_weight_MHA: " << n_weight_MHA << ", n_weight_FF: " << n_weight_FF;
  float sbh = seq_length_ * b * hidden_size_;
  float A_norm = n_norm * (FP_norm * sbh) / t;
  float A_MHA =
    n_MHA * ((n_weight_MHA * FP_params * sbh + FP_softmax * seq_length_ * seq_length_ * b * attention_heads_ +
              FP_dropout * 2 * seq_length_ * seq_length_ * b * attention_heads_) /
               t +
             FP_dropout * sbh / t);
  float A_FF = n_FF * (static_cast<float>(n_weight_FF * FP_params * seq_length_ * b * expansion_ratio_) / t +
                       FP_dropout * sbh / t);
  float A_intermediate = A_norm + A_MHA + A_FF;
  float nodes = static_cast<float>(num_devices_) / 8;
  float A_input = ((p > 1 && nodes == 1) ? m : ceil(nodes / 4)) * sbh / t;
  float n_Checkpoints = num_layers_ / 1.5;
  float full_recompute_size = n_Checkpoints * A_input + (l / n_Checkpoints) * A_intermediate;

  float communications_size = static_cast<float>(8 * seq_length_ * b * hidden_size_ * (t - 1)) / t;
583 MS_LOG(DEBUG) << "l: " << l;
584 MS_LOG(DEBUG) << "n_Checkpoints: " << n_Checkpoints;
585 MS_LOG(DEBUG) << "A_input: " << TSepBytes(static_cast<size_t>(A_input));
586 MS_LOG(DEBUG) << "A_norm: " << TSepBytes(static_cast<size_t>(A_norm));
587 MS_LOG(DEBUG) << "A_MHA: " << TSepBytes(static_cast<size_t>(A_MHA));
588 MS_LOG(DEBUG) << "A_FF: " << TSepBytes(static_cast<size_t>(A_FF));
589 MS_LOG(DEBUG) << "A_intermediate: " << TSepBytes(static_cast<size_t>(A_intermediate));
590 MS_LOG(DEBUG) << "l*communication: " << TSepBytes(static_cast<size_t>(l * communications_size));
591 MS_LOG(DEBUG) << "l/n_Checkpoints * A_intermediate: "
592 << TSepBytes(static_cast<size_t>((l / n_Checkpoints) * A_intermediate));
593 MS_LOG(DEBUG) << "n_Checkpoints * A_input: " << TSepBytes(static_cast<size_t>(n_Checkpoints * A_input));
594
595 return static_cast<size_t>(full_recompute_size + l * communications_size);
596 }

// Manually compute global batch size
size_t StageComputing::GlobalBatchSize() { return per_batch_ * num_micros_; }

// Get layers per stage
size_t StageComputing::GetLayerPerStage() {
  return ceil(static_cast<float>(num_layers_) / static_cast<float>(pp_dim_));
}

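// Total per-stage estimate: static memory (parameters, accumulated gradients, optimizer states)
// plus dynamic memory (activations and communication buffers).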
size_t StageComputing::GetMemory() {
  size_t P3 = NumParametersParsing(GetLayerPerStage());
  if (P3 == PARSING_FAILED) {
    return PARSING_FAILED;
  }
  size_t sMem3 = GetStaticMemoryParsing(dp_dim_, mp_dim_, pp_dim_, P3);
  if (sMem3 == PARSING_FAILED) {
    return PARSING_FAILED;
  }
  size_t dMem3 = GetDynamicMemoryParsing(GetLayerPerStage(), per_batch_, num_micros_, pp_dim_, mp_dim_);
  PrintResults(sMem3, dMem3, P3);
  return (sMem3 + dMem3);
}

// Log the parsed hyperparameters
void StageComputing::PrintHyperparams() {
  MS_LOG(INFO) << "Hyperparameters : h : " << hidden_size_ << ", s : " << seq_length_ << ", v : " << vocab_size_
               << ", a : " << attention_heads_ << ", L : " << num_layers_ << ", pb: " << per_batch_
               << ", B: " << GlobalBatchSize() << ", er: " << expansion_ratio_ << ", opt: " << parallel_opt_
               << ", rcpt: " << recompute_;
}

void Suggestion(const std::string &suggestion) {
  MS_LOG(INFO) << std::endl
               << "=================== Auto Parallel Config Suggestion by SAPP ===================" << std::endl
               << suggestion << std::endl
               << "===============================================================================";
}

std::string ParamSuggest(float mem_coeff, size_t stage, size_t batch, size_t micro) {
  std::stringstream ss;

  ss << " mem_coeff: " << mem_coeff;
  ss << ", pipeline_stage: " << stage << std::endl;
  ss << " batch_size: " << batch;
  ss << ", micro_batch_num: " << micro;

  return ss.str();
}

void StageComputing::FittingSuggestion() {
  std::stringstream ss;

  ss << " SAPP algorithm suggests the following parallel configuration:" << std::endl;
  ss << ParamSuggest(CostModelContext::GetInstance()->rp_matmul_mem_coef(), pp_dim_, per_batch_, num_micros_);

  Suggestion(ss.str());
}

void StageComputing::OOMSuggestion() {
  std::stringstream ss;

  float default_coeff = 1024;
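  // Suggest a per-device batch of 1 and fold the factor into micro_batch_num,
  // keeping the global batch size unchanged.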
659 ss << " The current configuration seem to not fit in memory." << std::endl;
660 ss << " SAPP algorithm suggests to change configuration to:" << std::endl;
661 ss << ParamSuggest(default_coeff, pp_dim_, 1, num_micros_ * per_batch_);
662
663 Suggestion(ss.str());
664 }

void StageComputing::ParsingException() {
  MS_LOG(WARNING) << "Something went wrong during the graph parsing process.";
  MS_LOG(WARNING) << "SAPP algorithm uses original stage number";
}

void StageComputing::PrintResults(size_t StaticMEM, size_t DynamicMEM, size_t num_param) {
  MS_LOG(INFO) << "DP: " << dp_dim_ << ", MP: " << mp_dim_ << ", Stages: " << pp_dim_ << ", n_micros: " << num_micros_
               << ", Per_batch: " << per_batch_ << ", Global batch size: " << GlobalBatchSize()
               << ", Num Parameters: " << num_param;
  MS_LOG(INFO) << "StaticMEM: " << TSepBytes(StaticMEM) << ", DynamicMEM: " << TSepBytes(DynamicMEM)
               << ", totalMEM: " << TSepBytes(StaticMEM + DynamicMEM);
}

size_t StageComputing::CurrentEstimation() { return GetMemory(); }

bool StageComputing::fits(size_t memory) { return ((MARGIN_FACTOR * memory) < device_capacity_); }

// Suggest a parallel config (dp, mp, pp) and batch split (per-batch, micro) that
// maintains the global batch size, using the memory estimate as the cost model
Status StageComputing::FindSmallerStage() {
  if (pp_dim_ == 1) {
    return FAILED;
  }
  SaveConfig();

  bool saved = false;
  size_t factor = 2;
  double ratio = static_cast<float>(dp_dim_) / static_cast<float>(mp_dim_);
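  // Halve pp down to 1, each time granting the freed factor of 2 to dp or mp (whichever keeps
  // dp/mp closest to the original ratio), and keep the first configuration that fits.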
  for (; pp_dim_ >= 1; pp_dim_ /= factor) {
    if (fits(GetMemory()) && !saved) {
      SaveConfig();
      saved = true;
      MS_LOG(INFO) << "Stage " << pp_dim_ << " is selected";
    }
    float ratio_factor_dp = std::abs(static_cast<float>(factor * dp_dim_) / mp_dim_ - ratio);
    float ratio_factor_mp = std::abs(static_cast<float>(dp_dim_) / (factor * mp_dim_) - ratio);
    if (mp_dim_ == 1)
      dp_dim_ *= factor;
    else if (dp_dim_ == 1)
      mp_dim_ *= factor;
    else if (ratio_factor_dp <= ratio_factor_mp)
      dp_dim_ *= factor;
    else
      mp_dim_ *= factor;
  }

  LoadConfig();

  if (!saved) {
    return FAILED;
  }
  return SUCCESS;
}


// Estimation compute
size_t StageComputing::LaunchStageCompute() {
  size_t pp = pp_dim_;
  if (vocab_size_ == PARSING_FAILED || seq_length_ == PARSING_FAILED || expansion_ratio_ == PARSING_FAILED ||
      num_layers_ == PARSING_FAILED || dp_dim_ == PARSING_FAILED || num_micros_ == PARSING_FAILED ||
      per_batch_ == PARSING_FAILED) {
    ParsingException();
    return pp;
  }

  MS_LOG(INFO) << "Current stage number : " << pp;
  MS_LOG(DEBUG) << "Number of devices: " << num_devices_;

  PrintHyperparams();
  if (CurrentEstimation() == PARSING_FAILED) {
    ParsingException();
    return pp;
  }
  if (FindSmallerStage() == FAILED) {
    OOMSuggestion();
    return pp;
  } else {
    FittingSuggestion();
  }

  return pp_dim_;
}

// Suggest a pipeline stage
size_t ParallelSuggestion(const FuncGraphPtr &root, const std::shared_ptr<Graph> &graph) {
  size_t vocab;
  size_t seq;
  size_t heads;
  size_t dp;
  size_t mp;
  size_t pp;
  size_t hidden;
  size_t layers;
  size_t devices;
  size_t capacity;
  size_t micros;
  size_t per_batch;
  size_t er;
  bool opt;
  bool recompute;

  pp = static_cast<size_t>(parallel::ParallelContext::GetInstance()->pipeline_stage_split_num());
  if (root == nullptr || graph == nullptr) {
    MS_LOG(WARNING) << "Null costgraph or recgraph, ParallelSuggestion cannot run.";
    MS_LOG(WARNING) << "SAPP algorithm uses original stage number.";
    return pp;
  }

  std::tie(seq, heads) = GetSeqLengthAndAttentionHeads(root);
  std::tie(dp, mp) = GetDPAndMP(graph, pp);
  std::tie(hidden, vocab) = GetVocabAndHiddenSize(root);
  er = GetExpansionRatio(root);
  layers = GetNumLayers(root);
  capacity = GetDeviceCapacity();
  micros = GetNumMicro(root);

  per_batch = GetPerBatch(root, seq);
  devices = GetNumDevices();
  opt = HasParallelOptimizer(root);
  recompute = HasRecompute(root);

  StageComputing sc(root, graph, devices, capacity, hidden, vocab, seq, heads, layers, er, dp, mp, pp, per_batch,
                    micros, opt, recompute);
  pp = sc.LaunchStageCompute();
  return pp;
}

bool IsGraphFilter(const AnfNodePtr &node) { return !IsValueNode<FuncGraph>(node); }

// Update old stage number with suggestion
void ChangeStageNumber(const FuncGraphPtr &root, size_t new_stage_num) {
  size_t old_stage = static_cast<size_t>(parallel::ParallelContext::GetInstance()->pipeline_stage_split_num());
  if (old_stage == new_stage_num) {
    MS_LOG(INFO) << "Stage number " << new_stage_num << " is the same as the old value. Nothing changed.";
    return;
  }

  if (old_stage % new_stage_num != 0) {
    MS_LOG(WARNING) << "Stage number " << new_stage_num << " is not a divisor of the previous stage number "
                    << old_stage << ". Stage Number is NOT changed.";
    return;
  }

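  // Only shrinking by an integer factor is supported; every pipelined sub-graph's stage id
  // is divided by this factor below.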
  size_t change_factor = old_stage / new_stage_num;
  MS_LOG(DEBUG) << "Old stage number:" << old_stage << " ; Change factor:" << change_factor;

  FuncGraphPtr main_graph;
  // Get main graph
  auto manager = root->manager();
  if (!root->has_flag(kTraining)) {
    main_graph = root;
  } else {
    for (auto &fg : manager->func_graphs()) {
      for (auto &node : fg->nodes()) {
        if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
          main_graph = fg;
          break;
        }
      }
    }
  }
  MS_EXCEPTION_IF_NULL(main_graph);

  // Get all sub graphs
  auto nodes = DeepScopedGraphSearchWithFilter(main_graph->get_return(), AlwaysInclude, IsGraphFilter);
  std::reverse(nodes.begin(), nodes.end());
  std::vector<FuncGraphPtr> subgraphs;
  for (auto &node : nodes) {
    auto graph = GetValueNode<FuncGraphPtr>(node);
    subgraphs.push_back(graph);
  }

  // Update stage in all sub_graphs
  for (auto &graph : subgraphs) {
    int graph_old_stage = graph->stage();
    if (graph_old_stage != -1) {
      graph->set_stage(graph_old_stage / change_factor);  // change_factor >= 2, so stage ids only shrink
    }
  }

  // Update stage in parallel context
  parallel::ParallelContext::GetInstance()->set_pipeline_stage_split_num(old_stage / change_factor);
  MS_LOG(INFO) << "END ChangeStageNumber"
               << ", new stage number: " << (old_stage / change_factor);
}

}  // namespace parallel
}  // namespace mindspore