• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_
19 
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include <utility>
24 #include "nlohmann/json.hpp"
25 #include "utils/hash_map.h"
26 #include "frontend/parallel/strategy.h"
27 #include "frontend/parallel/tensor_layout/tensor_layout.h"
28 #include "frontend/parallel/tensor_layout/tensor_info.h"
29 #include "proto/node_strategy.pb.h"
30 
31 namespace mindspore {
32 namespace parallel {
33 using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>;
34 using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
35 using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>;
36 using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
37 using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>;
38 using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
39 
40 class StrategyCheckpointInfo {
41  public:
StrategyCheckpointInfo()42   StrategyCheckpointInfo() : current_stage_(0) {}
43   virtual ~StrategyCheckpointInfo() = default;
Init(const StrategyMap & strategy_map,const TensorInfoMap & tensor_info_map,const ManualShapeMap & manual_shape_map,int64_t current_stage)44   void Init(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
45             const ManualShapeMap &manual_shape_map, int64_t current_stage) {
46     strategy_map_ = strategy_map;
47     tensor_info_map_ = tensor_info_map;
48     manual_shape_map_ = manual_shape_map;
49     current_stage_ = current_stage;
50   }
strategy_map()51   StrategyMap strategy_map() const { return strategy_map_; }
52   void set_strategy_map(const StrategyMap &strategy_map);
tensor_info_map()53   TensorInfoMap tensor_info_map() const { return tensor_info_map_; }
54   void set_tensor_info_map(const TensorInfoMap &tensor_info_map);
manual_shape_map()55   ManualShapeMap manual_shape_map() const { return manual_shape_map_; }
56   void set_manual_shape_map(const ManualShapeMap &manual_shape_map);
current_stage()57   int64_t current_stage() const { return current_stage_; }
58 
59   virtual void FromJson(const nlohmann::json &stra_ckpt_info_j);
60   nlohmann::json to_json() const;
61 
62   void from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map);
63   straspb::ParallelStrategyMap to_protobuf() const;
64 
65  protected:
66   StrategyMap strategy_map_;
67   int64_t current_stage_;
68   TensorInfoMap tensor_info_map_;
69   ManualShapeMap manual_shape_map_;
70 };
71 
72 class StrategyJsonInfo : public StrategyCheckpointInfo {
73  public:
StrategyJsonInfo()74   StrategyJsonInfo() : StrategyCheckpointInfo() {}
75   ~StrategyJsonInfo() override = default;
76 
77   void FromJson(const nlohmann::json &stra_json_info_j) override;
78 };
79 }  // namespace parallel
80 }  // namespace mindspore
81 
82 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_
83