/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace mindspore {
namespace parallel {
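// Setters replace the stored operator-strategy, tensor-layout and manual-shape maps wholesale.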
void StrategyCheckpointInfo::set_strategy_map(const StrategyMap &strategy_map) { strategy_map_ = strategy_map; }

void StrategyCheckpointInfo::set_tensor_info_map(const TensorInfoMap &tensor_info_map) {
  tensor_info_map_ = tensor_info_map;
}

void StrategyCheckpointInfo::set_manual_shape_map(const ManualShapeMap &manual_shape_map) {
  manual_shape_map_ = manual_shape_map;
}

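// Restores the checkpoint state (current stage, per-operator strategies, tensor layouts and
// manual shapes for field-split parameters) from the JSON layout produced by to_json().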
void StrategyCheckpointInfo::FromJson(const nlohmann::json &stra_ckpt_info_j) {
  current_stage_ = stra_ckpt_info_j.at("current_stage").get<int64_t>();
  for (const auto &stra_j : stra_ckpt_info_j.at("parallel_strategy_item").items()) {
    auto node_name = stra_j.key();
    auto stage = stra_j.value().at("stage").get<int64_t>();
    auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
    strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
  }
  for (const auto &layout_j : stra_ckpt_info_j.at("parallel_layout_item").items()) {
    auto parameter_name = layout_j.key();
    auto dev_matrix = layout_j.value().at("dev_matrix").get<std::vector<int64_t>>();
    auto tensor_map = layout_j.value().at("tensor_map").get<std::vector<int64_t>>();
    auto tensor_shape = layout_j.value().at("tensor_shape").get<std::vector<int64_t>>();
    auto field = layout_j.value().at("field").get<int64_t>();
    auto opt_weight_shard_step = layout_j.value().at("opt_weight_shard_step").get<int64_t>();
    auto opt_weight_shard_size = layout_j.value().at("opt_weight_shard_size").get<int64_t>();
    if (layout_j.value().contains("param_split_shape") && layout_j.value().contains("indices_offset")) {
      auto param_split_shape = layout_j.value().at("param_split_shape").get<std::vector<int64_t>>();
      auto indices_offset = layout_j.value().at("indices_offset").get<std::vector<int64_t>>();
      if (param_split_shape.size() != indices_offset.size()) {
        MS_LOG(EXCEPTION) << "For field_split strategy, the size of param_split_shape " << param_split_shape.size()
                          << " is not equal to the size of indices_offset " << indices_offset.size();
      }
      for (size_t i = 0; i < param_split_shape.size(); ++i) {
        manual_shape_map_[parameter_name].push_back({param_split_shape[i], indices_offset[i]});
      }
    }
    tensor_info_map_[parameter_name] = std::make_shared<TensorLayout>();
    (void)tensor_info_map_[parameter_name]->InitFromVector(dev_matrix, tensor_map, tensor_shape);
    tensor_info_map_[parameter_name]->set_opt_weight_shard_size(opt_weight_shard_size);
    tensor_info_map_[parameter_name]->set_opt_weight_shard_step(opt_weight_shard_step);
    tensor_info_map_[parameter_name]->set_field_size(field);
  }
}

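// Serializes the checkpoint state into a JSON object whose layout mirrors what FromJson() expects.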
nlohmann::json StrategyCheckpointInfo::to_json() const {
  nlohmann::json stra_ckpt_info_j;
  stra_ckpt_info_j["current_stage"] = current_stage_;
  for (const auto &stra_pair : strategy_map_) {
    auto node_name = stra_pair.first;
    auto node_stra = stra_pair.second;
    nlohmann::json stra_j;
    stra_j["stage"] = node_stra->GetInputStage();
    stra_j["parallel_strategy"] = node_stra->GetInputDim();
    stra_ckpt_info_j["parallel_strategy_item"][node_name] = stra_j;
  }
  for (const auto &layout_pair : tensor_info_map_) {
    auto parameter_name = layout_pair.first;
    auto layout = layout_pair.second;
    nlohmann::json layout_j;
    layout_j["dev_matrix"] = layout->device_arrangement().array();
    layout_j["tensor_map"] = layout->tensor_map().array();
    layout_j["tensor_shape"] = layout->tensor_shape().array();
    layout_j["field"] = layout->get_field_size();
    layout_j["opt_weight_shard_step"] = layout->opt_weight_shard_step();
    layout_j["opt_weight_shard_size"] = layout->opt_weight_shard_size();
    if (manual_shape_map_.find(parameter_name) != manual_shape_map_.end()) {
      auto manual_shape = manual_shape_map_.at(parameter_name);
      for (auto dim_pair : manual_shape) {
        layout_j["param_split_shape"].push_back(dim_pair.first);
        layout_j["indices_offset"].push_back(dim_pair.second);
      }
    }
    stra_ckpt_info_j["parallel_layout_item"][parameter_name] = layout_j;
  }
  return stra_ckpt_info_j;
}

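// Rebuilds current_stage_ and strategy_map_ from the protobuf form of the strategy checkpoint.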
void StrategyCheckpointInfo::from_protobuf(const straspb::ParallelStrategyMap &parallel_strategy_map) {
  // current_stage_ applies to the whole checkpoint, so read it once before iterating the items.
  current_stage_ = SizeToLong(parallel_strategy_map.current_stage());
  size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
  for (size_t i = 0; i < node_num; i++) {
    straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
    std::string node_name = parallel_strategy_item.node_name();
    straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
    int64_t stage = SizeToLong(parallel_strategys.stage());
    size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
    Strategies strategy_inputs;
    for (size_t j = 0; j < strategys_num; j++) {
      straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
      Dimensions dimension;
      size_t dim_num = LongToSize(parallel_strategy.dim_size());
      for (size_t k = 0; k < dim_num; k++) {
        dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
      }
      strategy_inputs.push_back(dimension);
    }
    StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
    strategy_map_[node_name] = strategy;
  }
}

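// Exports strategies, tensor layouts and manual shapes to the straspb protobuf message.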
straspb::ParallelStrategyMap StrategyCheckpointInfo::to_protobuf() const {
  straspb::ParallelStrategyMap parallel_strategy_map;
  parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(current_stage_)));
  for (auto &node_stra : strategy_map_) {
    straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
    MS_EXCEPTION_IF_NULL(parallel_strategy_item);
    parallel_strategy_item->set_node_name(node_stra.first);
    straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
    MS_EXCEPTION_IF_NULL(parallel_strategys);
    MS_EXCEPTION_IF_NULL(node_stra.second);
    parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
    for (auto &dims : node_stra.second->GetInputDim()) {
      straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
      MS_EXCEPTION_IF_NULL(parallel_strategy);
      for (auto stra_dim : dims) {
        parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
      }
    }
  }
  for (auto &node_tensor_info : tensor_info_map_) {
    TensorLayoutPtr tensor_layout = node_tensor_info.second;
    MS_EXCEPTION_IF_NULL(tensor_layout);
    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
    MS_EXCEPTION_IF_NULL(parallel_layout_item);
    parallel_layout_item->set_param_name(node_tensor_info.first);
    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
    MS_EXCEPTION_IF_NULL(parallel_layouts);
    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
    MS_EXCEPTION_IF_NULL(dev_matrix);
    for (auto dev_dim : tensor_layout->device_arrangement().array()) {
      dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
    }
    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
    MS_EXCEPTION_IF_NULL(tensor_map);
    for (auto map_dim : tensor_layout->tensor_map().array()) {
      tensor_map->add_dim(LongToInt(map_dim));
    }
    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
    MS_EXCEPTION_IF_NULL(param_split_shape);
    straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
    MS_EXCEPTION_IF_NULL(indices_offset);
    parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
    if (manual_shape_map_.find(node_tensor_info.first) != manual_shape_map_.end()) {
      auto manual_shape = manual_shape_map_.at(node_tensor_info.first);
      for (auto dim_pair : manual_shape) {
        param_split_shape->add_dim(dim_pair.first);
        indices_offset->add_dim(dim_pair.second);
      }
    }
  }
  return parallel_strategy_map;
}

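// Parses only the per-operator strategies from a strategy JSON file; layouts are not loaded here.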
void StrategyJsonInfo::FromJson(const nlohmann::json &stra_json_info_j) {
  for (const auto &stra_j : stra_json_info_j.at("parallel_strategy_item").items()) {
    auto node_name = stra_j.key();
    auto stage = stra_j.value().at("stage").get<int64_t>();
    auto stra = stra_j.value().at("parallel_strategy").get<std::vector<std::vector<int64_t>>>();
    strategy_map_[node_name] = std::make_shared<Strategy>(stage, stra);
  }
}
}  // namespace parallel
}  // namespace mindspore