• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"
18 
19 #include <vector>
20 #include <utility>
21 
22 namespace mindspore {
23 namespace parallel {
set_strategy_map(const StrategyMap & strategy_map)24 void StrategyCheckpointInfo::set_strategy_map(const StrategyMap &strategy_map) { strategy_map_ = strategy_map; }
25 
set_tensor_info_map(const TensorInfoMap & tensor_info_map)26 void StrategyCheckpointInfo::set_tensor_info_map(const TensorInfoMap &tensor_info_map) {
27   tensor_info_map_ = tensor_info_map;
28 }
29 
set_manual_shape_map(const ManualShapeMap & manual_shape_map)30 void StrategyCheckpointInfo::set_manual_shape_map(const ManualShapeMap &manual_shape_map) {
31   manual_shape_map_ = manual_shape_map;
32 }
33 
FromJson(const nlohmann::json & stra_ckpt_info_j)34 void StrategyCheckpointInfo::FromJson(const nlohmann::json &stra_ckpt_info_j) {
35   current_stage_ = stra_ckpt_info_j.at("current_stage").get<int64_t>();
36   for (const auto &stra_j : stra_ckpt_info_j.at("parallel_strategy_item").items()) {
37     auto node_name = stra_j.key();
38     auto stage = stra_j.value().at("stage").get<int64_t>();
39     auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
40     strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
41   }
42   for (const auto &layout_j : stra_ckpt_info_j.at("parallel_layout_item").items()) {
43     auto parameter_name = layout_j.key();
44     auto dev_matrix = layout_j.value().at("dev_matrix").get<std::vector<int64_t>>();
45     auto tensor_map = layout_j.value().at("tensor_map").get<std::vector<int64_t>>();
46     auto tensor_shape = layout_j.value().at("tensor_shape").get<std::vector<int64_t>>();
47     auto field = layout_j.value().at("field").get<int64_t>();
48     auto opt_weight_shard_step = layout_j.value().at("opt_weight_shard_step").get<int64_t>();
49     auto opt_weight_shard_size = layout_j.value().at("opt_weight_shard_size").get<int64_t>();
50     if (layout_j.value().contains("param_split_shape") && layout_j.value().contains("indices_offset")) {
51       auto param_split_shape = layout_j.value().at("param_split_shape").get<std::vector<int64_t>>();
52       auto indices_offset = layout_j.value().at("indices_offset").get<std::vector<int64_t>>();
53       if (param_split_shape.size() != indices_offset.size()) {
54         MS_LOG(EXCEPTION) << "For field_split strategy, the size of param_split_shape " << param_split_shape.size()
55                           << " is not equal to the size of indices_offset " << indices_offset.size();
56       }
57       for (size_t i = 0; i < param_split_shape.size(); ++i) {
58         manual_shape_map_[parameter_name].push_back({param_split_shape[i], indices_offset[i]});
59       }
60     }
61     tensor_info_map_[parameter_name] = std::make_shared<TensorLayout>();
62     (void)tensor_info_map_[parameter_name]->InitFromVector(dev_matrix, tensor_map, tensor_shape);
63     tensor_info_map_[parameter_name]->set_opt_weight_shard_size(opt_weight_shard_size);
64     tensor_info_map_[parameter_name]->set_opt_weight_shard_step(opt_weight_shard_step);
65     tensor_info_map_[parameter_name]->set_field_size(field);
66   }
67 }
68 
to_json() const69 nlohmann::json StrategyCheckpointInfo::to_json() const {
70   nlohmann::json stra_ckpt_info_j;
71   stra_ckpt_info_j["current_stage"] = current_stage_;
72   for (const auto &stra_pair : strategy_map_) {
73     auto node_name = stra_pair.first;
74     auto node_stra = stra_pair.second;
75     nlohmann::json stra_j;
76     stra_j["stage"] = node_stra->GetInputStage();
77     stra_j["parallel_strategy"] = node_stra->GetInputDim();
78     stra_ckpt_info_j["parallel_strategy_item"][node_name] = stra_j;
79   }
80   for (const auto &layout_pair : tensor_info_map_) {
81     auto parameter_name = layout_pair.first;
82     auto layout = layout_pair.second;
83     nlohmann::json layout_j;
84     layout_j["dev_matrix"] = layout->device_arrangement().array();
85     layout_j["tensor_map"] = layout->tensor_map().array();
86     layout_j["tensor_shape"] = layout->tensor_shape().array();
87     layout_j["field"] = layout->get_field_size();
88     layout_j["opt_weight_shard_step"] = layout->opt_weight_shard_step();
89     layout_j["opt_weight_shard_size"] = layout->opt_weight_shard_size();
90     if (manual_shape_map_.find(parameter_name) != manual_shape_map_.end()) {
91       auto manual_shape = manual_shape_map_.at(parameter_name);
92       for (auto dim_pair : manual_shape) {
93         layout_j["param_split_shape"].push_back(dim_pair.first);
94         layout_j["indices_offset"].push_back(dim_pair.second);
95       }
96     }
97     stra_ckpt_info_j["parallel_layout_item"][parameter_name] = layout_j;
98   }
99   return stra_ckpt_info_j;
100 }
101 
from_protobuf(const straspb::ParallelStrategyMap & parallel_strategy_map)102 void StrategyCheckpointInfo::from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map) {
103   size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
104   for (size_t i = 0; i < node_num; i++) {
105     straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
106     std::string node_name = parallel_strategy_item.node_name();
107     straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
108     int64_t stage = SizeToLong(parallel_strategys.stage());
109     size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
110     Strategies strategy_inputs;
111     for (size_t j = 0; j < strategys_num; j++) {
112       straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
113       Dimensions dimension;
114       size_t dim_num = LongToSize(parallel_strategy.dim_size());
115       for (size_t k = 0; k < dim_num; k++) {
116         dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
117       }
118       strategy_inputs.push_back(dimension);
119     }
120     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
121     strategy_map_[node_name] = strategy;
122     current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
123   }
124 }
125 
to_protobuf() const126 straspb::ParallelStrategyMap StrategyCheckpointInfo::to_protobuf() const {
127   straspb::ParallelStrategyMap parallel_strategy_map;
128   parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(current_stage_)));
129   for (auto &node_stra : strategy_map_) {
130     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
131     MS_EXCEPTION_IF_NULL(parallel_strategy_item);
132     parallel_strategy_item->set_node_name(node_stra.first);
133     straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
134     MS_EXCEPTION_IF_NULL(parallel_strategys);
135     MS_EXCEPTION_IF_NULL(node_stra.second);
136     parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
137     for (auto &dims : node_stra.second->GetInputDim()) {
138       straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
139       MS_EXCEPTION_IF_NULL(parallel_strategy);
140       for (auto stra_dim : dims) {
141         parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
142       }
143     }
144   }
145   for (auto &node_tensor_info : tensor_info_map_) {
146     TensorLayoutPtr tensor_layout = node_tensor_info.second;
147     MS_EXCEPTION_IF_NULL(tensor_layout);
148     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
149     MS_EXCEPTION_IF_NULL(parallel_layout_item);
150     parallel_layout_item->set_param_name(node_tensor_info.first);
151     straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
152     straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
153     MS_EXCEPTION_IF_NULL(dev_matrix);
154     for (auto dev_dim : tensor_layout->device_arrangement().array()) {
155       dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
156     }
157     straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
158     MS_EXCEPTION_IF_NULL(tensor_map);
159     for (auto map_dim : tensor_layout->tensor_map().array()) {
160       tensor_map->add_dim(LongToInt(map_dim));
161     }
162     straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
163     straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
164     parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
165     parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
166     parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
167     if (manual_shape_map_.find(node_tensor_info.first) != manual_shape_map_.end()) {
168       auto manual_shape = manual_shape_map_.at(node_tensor_info.first);
169       for (auto dim_pair : manual_shape) {
170         param_split_shape->add_dim(dim_pair.first);
171         indices_offset->add_dim(dim_pair.second);
172       }
173     }
174   }
175   return parallel_strategy_map;
176 }
FromJson(const nlohmann::json & stra_json_info_j)177 void StrategyJsonInfo::FromJson(const nlohmann::json &stra_json_info_j) {
178   for (const auto &stra_j : stra_json_info_j.at("parallel_strategy_item").items()) {
179     auto node_name = stra_j.key();
180     auto stage = stra_j.value().at("stage").get<int64_t>();
181     auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
182     strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
183   }
184 }
185 }  // namespace parallel
186 }  // namespace mindspore
187