1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ 19 20 #include <string> 21 #include <vector> 22 #include <memory> 23 #include <utility> 24 #include "utils/hash_map.h" 25 #include "frontend/parallel/strategy.h" 26 #include "include/common/utils/parallel_context.h" 27 #include "frontend/parallel/tensor_layout/tensor_layout.h" 28 #include "frontend/parallel/tensor_layout/tensor_info.h" 29 #include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h" 30 31 namespace mindspore { 32 namespace parallel { 33 class StrategyCheckpoint { 34 public: StrategyCheckpoint()35 StrategyCheckpoint() { 36 load_file_ = ""; 37 save_file_ = ""; 38 group_info_save_file_ = ""; 39 auto_op_strategy_file_ = ""; 40 } 41 ~StrategyCheckpoint() = default; 42 43 Status Load(StrategyMap *strategy_map); 44 Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) const; 45 Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, 46 const ManualShapeMap &manual_shape_map); 47 Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list); group_info_save_on()48 bool group_info_save_on() const { return group_info_save_on_; } 49 50 static StrategyCheckpoint &GetInstance(); LoadCheckPointOn()51 bool LoadCheckPointOn() const { return load_checkpoint_on_; } SaveCheckPointOn()52 bool SaveCheckPointOn() const { return save_checkpoint_on_; } 53 set_common_mirror_group(const RankList & comm_group)54 void set_common_mirror_group(const RankList &comm_group) { common_mirror_group_ = comm_group; } common_mirror_group()55 RankList common_mirror_group() const { return common_mirror_group_; } 56 LoadAutoOpStrategyOn()57 bool LoadAutoOpStrategyOn() const { return load_auto_op_strategy_on_; } SaveAutoOpStrategyOn()58 bool SaveAutoOpStrategyOn() const { return save_auto_op_strategy_on_; } 59 Status LoadAutoOpStrategy(StrategyMap *strategy_map); 60 Status SaveAutoOpStrategy(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, 61 const ManualShapeMap &manual_shape_map); 62 63 private: 64 std::string auto_op_strategy_file_; 65 std::string auto_op_strategy_file_type_; 66 bool load_auto_op_strategy_on_ = false; 67 bool save_auto_op_strategy_on_ = false; 68 StrategyJsonInfo strategy_json_info_; 69 70 std::string load_file_; 71 std::string save_file_; 72 bool load_checkpoint_on_ = false; 73 bool save_checkpoint_on_ = false; 74 bool CheckPointExit(const std::string path) const; 75 bool CheckPath(const std::string path) const; 76 int64_t current_stage_ = 0; 77 std::string group_info_save_file_; 78 bool group_info_save_on_ = false; 79 bool load_format_json_ = true; 80 bool save_format_json_ = true; 81 StrategyCheckpointInfo strategy_checkpoint_info_; 82 RankList common_mirror_group_; 83 }; 84 } // namespace parallel 85 } // namespace mindspore 86 87 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ 88