• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
19 
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 #include <memory>
25 #include "frontend/parallel/ops_info/ops_utils.h"
26 #include "frontend/parallel/strategy.h"
27 #include "frontend/parallel/context.h"
28 #include "frontend/parallel/tensor_layout/tensor_layout.h"
29 #include "frontend/parallel/tensor_layout/tensor_info.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
34 using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
35 using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
36 using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
37 using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
38 using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
39 class StrategyCheckpoint {
40  public:
StrategyCheckpoint()41   StrategyCheckpoint() {
42     current_stage_ = 0;
43     load_file_ = "";
44     load_checkpoint_on_ = false;
45     save_file_ = "";
46     save_checkpoint_on_ = false;
47     group_info_save_file_ = "";
48     group_info_save_on_ = false;
49   }
50   ~StrategyCheckpoint() = default;
51 
52   Status Load(StrategyMap *strategy_map);
53   Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
54   Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
55   Status SaveGroupInfo(const GroupInfoMap &group_info_map);
group_info_save_on()56   bool group_info_save_on() const { return group_info_save_on_; }
57 
58   static StrategyCheckpoint &GetInstance();
LoadCheckPointOn()59   bool LoadCheckPointOn() const { return load_checkpoint_on_; }
SaveCheckPointOn()60   bool SaveCheckPointOn() const { return save_checkpoint_on_; }
61 
62  private:
63   std::string load_file_;
64   std::string save_file_;
65   bool load_checkpoint_on_;
66   bool save_checkpoint_on_;
67   bool CheckPointExit(const std::string path) const;
68   bool CheckPath(const std::string path) const;
69   int64_t current_stage_;
70   std::string group_info_save_file_;
71   bool group_info_save_on_;
72 };
73 }  // namespace parallel
74 }  // namespace mindspore
75 
76 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
77