/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_STAGE_COMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_STAGE_COMPUTE_H_

#include <memory>
#include <tuple>

#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "include/common/utils/parallel_context.h"
#include "mindspore/core/ops/other_ops.h"

namespace mindspore {
namespace parallel {

// Helpers that parse the model hyperparameters and the current parallel
// configuration out of the function graph.
std::tuple<size_t, size_t> GetSeqLengthAndAttentionHeads(const FuncGraphPtr &root);
std::tuple<size_t, size_t> GetDPAndMP(const std::shared_ptr<Graph> &graph, const size_t stage);
std::tuple<size_t, size_t> GetVocabAndHiddenSize(const FuncGraphPtr &root);
size_t GetNumLayers(const FuncGraphPtr &root);
size_t GetNumMicro(const FuncGraphPtr &root);
size_t GetPerBatch(const FuncGraphPtr &root, size_t seq_l);
size_t GetNumDevices();

// Estimates the per-device memory footprint of the current parallel
// configuration and searches for a pipeline-stage number that fits the
// device capacity, printing suggestions when the configuration does not fit.
class StageComputing {
 public:
  StageComputing(const FuncGraphPtr &r, const std::shared_ptr<Graph> &g, size_t device_num, size_t device_capacity,
                 size_t hidden_size, size_t vocab_size, size_t seq_length, size_t head_num, size_t layer_num,
                 size_t expansion_ratio, size_t dp, size_t mp, size_t pp, size_t per_batch, size_t micro,
                 bool parallel_opt, bool recompute);

  size_t GlobalBatchSize();
  size_t CurrentEstimation();
  Status FindSmallerStage();
  size_t LaunchStageCompute();
  void PrintHyperparams();
  void PrintResults(size_t static_mem, size_t dynamic_mem, size_t num_param);
  void ParsingException();
  void OOMSuggestion();
  void FittingSuggestion();

 private:
  const FuncGraphPtr &root_;
  const std::shared_ptr<Graph> &graph_;
  // Hyperparameters
  const size_t num_devices_ = 0;
  const size_t device_capacity_ = 0;
  const size_t vocab_size_ = 0;
  const size_t seq_length_ = 0;
  const size_t hidden_size_ = 0;
  const size_t attention_heads_ = 0;
  const size_t num_layers_ = 0;
  const size_t expansion_ratio_ = 0;

  const bool parallel_opt_ = false;
  const bool recompute_ = false;

  // Parallelism parameters
  size_t dp_dim_ = 0;
  size_t mp_dim_ = 0;
  size_t pp_dim_ = 0;
  size_t per_batch_ = 0;
  size_t num_micros_ = 0;

  // Snapshot of the five parallelism parameters above, so a candidate
  // configuration can be tried and rolled back.
  std::tuple<size_t, size_t, size_t, size_t, size_t> saved_config_;
  void SaveConfig();
  void LoadConfig();

  // Memory-model estimators.
  size_t NumParametersParsing(size_t l);
  size_t GetStaticMemoryParsing(size_t d, size_t t, size_t p, size_t P);
  size_t GetDynamicMemoryParsing(size_t l, size_t b, size_t m, size_t p, size_t t);

  size_t GetLayerPerStage();
  size_t GetMemory();
  bool fits(size_t memory);
};

size_t ParallelSuggestion(const FuncGraphPtr &root, const std::shared_ptr<Graph> &graph);
void ChangeStageNumber(const FuncGraphPtr &root, size_t new_stage_num);
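
// Usage sketch: one plausible way a caller could drive this pass end to end.
// Illustrative only; the return-value contract assumed here (ParallelSuggestion
// returns the suggested stage number, with 0 taken to mean "keep the current
// one") should be checked against the matching .cc file.
//
//   size_t suggested_stages = ParallelSuggestion(root, graph);
//   if (suggested_stages != 0) {
//     ChangeStageNumber(root, suggested_stages);
//   }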
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_STAGE_COMPUTE_H_