1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_ 19 20 #include <string> 21 #include <vector> 22 #include <memory> 23 #include <utility> 24 #include "nlohmann/json.hpp" 25 #include "utils/hash_map.h" 26 #include "frontend/parallel/strategy.h" 27 #include "frontend/parallel/tensor_layout/tensor_layout.h" 28 #include "frontend/parallel/tensor_layout/tensor_info.h" 29 #include "proto/node_strategy.pb.h" 30 31 namespace mindspore { 32 namespace parallel { 33 using StrategyMap = mindspore::HashMap<std::string, StrategyPtr>; 34 using TensorLayoutPtr = std::shared_ptr<TensorLayout>; 35 using TensorInfoMap = mindspore::HashMap<std::string, TensorLayoutPtr>; 36 using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>; 37 using ManualShapeMap = mindspore::HashMap<std::string, std::vector<std::pair<int64_t, int64_t>>>; 38 using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>; 39 40 class StrategyCheckpointInfo { 41 public: StrategyCheckpointInfo()42 StrategyCheckpointInfo() : current_stage_(0) {} 43 virtual ~StrategyCheckpointInfo() = default; Init(const StrategyMap & strategy_map,const TensorInfoMap & tensor_info_map,const ManualShapeMap & manual_shape_map,int64_t current_stage)44 void Init(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, 45 const ManualShapeMap &manual_shape_map, int64_t current_stage) { 46 strategy_map_ = strategy_map; 47 tensor_info_map_ = tensor_info_map; 48 manual_shape_map_ = manual_shape_map; 49 current_stage_ = current_stage; 50 } strategy_map()51 StrategyMap strategy_map() const { return strategy_map_; } 52 void set_strategy_map(const StrategyMap &strategy_map); tensor_info_map()53 TensorInfoMap tensor_info_map() const { return tensor_info_map_; } 54 void set_tensor_info_map(const TensorInfoMap &tensor_info_map); manual_shape_map()55 ManualShapeMap manual_shape_map() const { return manual_shape_map_; } 56 void set_manual_shape_map(const ManualShapeMap &manual_shape_map); current_stage()57 int64_t current_stage() const { return current_stage_; } 58 59 virtual void FromJson(const nlohmann::json &stra_ckpt_info_j); 60 nlohmann::json to_json() const; 61 62 void from_protobuf(const straspb::ParallelStrategyMap ¶llel_strategy_map); 63 straspb::ParallelStrategyMap to_protobuf() const; 64 65 protected: 66 StrategyMap strategy_map_; 67 int64_t current_stage_; 68 TensorInfoMap tensor_info_map_; 69 ManualShapeMap manual_shape_map_; 70 }; 71 72 class StrategyJsonInfo : public StrategyCheckpointInfo { 73 public: StrategyJsonInfo()74 StrategyJsonInfo() : StrategyCheckpointInfo() {} 75 ~StrategyJsonInfo() override = default; 76 77 void FromJson(const nlohmann::json &stra_json_info_j) override; 78 }; 79 } // namespace parallel 80 } // namespace mindspore 81 82 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_CHEKCPOINT_STRATEGY_CHECKPOINT_INFO_H_ 83