• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
19 
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include <utility>
24 #include "utils/hash_map.h"
25 #include "frontend/parallel/strategy.h"
26 #include "include/common/utils/parallel_context.h"
27 #include "frontend/parallel/tensor_layout/tensor_layout.h"
28 #include "frontend/parallel/tensor_layout/tensor_info.h"
29 #include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 class StrategyCheckpoint {
34  public:
StrategyCheckpoint()35   StrategyCheckpoint() {
36     load_file_ = "";
37     save_file_ = "";
38     group_info_save_file_ = "";
39     auto_op_strategy_file_ = "";
40   }
41   ~StrategyCheckpoint() = default;
42 
43   Status Load(StrategyMap *strategy_map);
44   Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) const;
45   Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
46               const ManualShapeMap &manual_shape_map);
47   Status SaveGroupInfo(const GroupInfoMap &group_info_map, const RankList &restore_rank_list);
group_info_save_on()48   bool group_info_save_on() const { return group_info_save_on_; }
49 
50   static StrategyCheckpoint &GetInstance();
LoadCheckPointOn()51   bool LoadCheckPointOn() const { return load_checkpoint_on_; }
SaveCheckPointOn()52   bool SaveCheckPointOn() const { return save_checkpoint_on_; }
53 
set_common_mirror_group(const RankList & comm_group)54   void set_common_mirror_group(const RankList &comm_group) { common_mirror_group_ = comm_group; }
common_mirror_group()55   RankList common_mirror_group() const { return common_mirror_group_; }
56 
LoadAutoOpStrategyOn()57   bool LoadAutoOpStrategyOn() const { return load_auto_op_strategy_on_; }
SaveAutoOpStrategyOn()58   bool SaveAutoOpStrategyOn() const { return save_auto_op_strategy_on_; }
59   Status LoadAutoOpStrategy(StrategyMap *strategy_map);
60   Status SaveAutoOpStrategy(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
61                             const ManualShapeMap &manual_shape_map);
62 
63  private:
64   std::string auto_op_strategy_file_;
65   std::string auto_op_strategy_file_type_;
66   bool load_auto_op_strategy_on_ = false;
67   bool save_auto_op_strategy_on_ = false;
68   StrategyJsonInfo strategy_json_info_;
69 
70   std::string load_file_;
71   std::string save_file_;
72   bool load_checkpoint_on_ = false;
73   bool save_checkpoint_on_ = false;
74   bool CheckPointExit(const std::string path) const;
75   bool CheckPath(const std::string path) const;
76   int64_t current_stage_ = 0;
77   std::string group_info_save_file_;
78   bool group_info_save_on_ = false;
79   bool load_format_json_ = true;
80   bool save_format_json_ = true;
81   StrategyCheckpointInfo strategy_checkpoint_info_;
82   RankList common_mirror_group_;
83 };
84 }  // namespace parallel
85 }  // namespace mindspore
86 
87 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
88