1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/ms_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "debug/common.h"
#include "proto/node_strategy.pb.h"
27
28 namespace mindspore {
29 namespace parallel {
GetInstance()30 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
31 static StrategyCheckpoint instance = StrategyCheckpoint();
32 if (ParallelContext::GetInstance() != nullptr) {
33 instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
34 instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
35 instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
36 instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
37 instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
38 instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
39 }
40 return instance;
41 }
42
CheckPath(const std::string path) const43 bool StrategyCheckpoint::CheckPath(const std::string path) const {
44 if (path.size() > PATH_MAX) {
45 MS_LOG(ERROR) << "The checkpoit path " << path << " is too long";
46 return false;
47 }
48 auto realpath = Common::CreatePrefixPath(path);
49 if (!realpath.has_value()) {
50 MS_LOG(ERROR) << "Get real path failed, path=" << path;
51 return false;
52 }
53 return true;
54 }
55
CheckPointExit(const std::string path) const56 bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
57 std::ifstream fin(path);
58 if (fin) {
59 return true;
60 }
61 return false;
62 }
63
LoadGroupInfo(const std::string & file,GroupInfoMap * group_info_map)64 Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
65 MS_EXCEPTION_IF_NULL(group_info_map);
66 if (!CheckPath(file)) {
67 MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
68 }
69 if (!CheckPointExit(file)) {
70 MS_LOG(EXCEPTION) << "CheckPoint file is not found";
71 }
72 straspb::ParallelGroupMap parallel_group_map;
73 std::fstream input(file, std::ios::in | std::ios::binary);
74 if (!parallel_group_map.ParseFromIstream(&input)) {
75 MS_LOG(ERROR) << "Load strategy file failed";
76 return FAILED;
77 }
78 input.close();
79
80 size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
81 for (size_t i = 0; i < group_num; ++i) {
82 straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToInt(i));
83 std::string group_name = parallel_group_item.group_name();
84
85 straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
86 size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
87 std::vector<uint32_t> ranks;
88 for (size_t j = 0; j < rank_num; ++j) {
89 uint32_t rank = parallel_group_ranks.dim(SizeToInt(j));
90 ranks.push_back(rank);
91 }
92
93 std::pair<std::string, std::vector<uint32_t>> group = std::make_pair(group_name, ranks);
94 group_info_map->push_back(group);
95 }
96
97 return SUCCESS;
98 }
99
Load(StrategyMap * strategy_map)100 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
101 if (strategy_map == nullptr) {
102 MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
103 }
104 if (!CheckPath(load_file_)) {
105 MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
106 }
107 if (!CheckPointExit(load_file_)) {
108 MS_LOG(EXCEPTION) << "CheckPoint file is not found";
109 }
110 straspb::ParallelStrategyMap parallel_strategy_map;
111 std::fstream input(load_file_, std::ios::in | std::ios::binary);
112 if (!parallel_strategy_map.ParseFromIstream(&input)) {
113 MS_LOG(ERROR) << "Load strategy file failed";
114 return FAILED;
115 }
116 input.close();
117 size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
118 for (size_t i = 0; i < node_num; i++) {
119 straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
120 std::string node_name = parallel_strategy_item.node_name();
121 straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
122 auto stage = (int64_t)parallel_strategys.stage();
123 size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
124 Strategys strategy_inputs;
125 for (size_t j = 0; j < strategys_num; j++) {
126 straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
127 Dimensions dimension;
128 size_t dim_num = LongToSize(parallel_strategy.dim_size());
129 for (size_t k = 0; k < dim_num; k++) {
130 dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
131 }
132 strategy_inputs.push_back(dimension);
133 }
134
135 StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
136 (*strategy_map)[node_name] = strategy;
137 current_stage_ = (int64_t)parallel_strategy_map.current_stage();
138 }
139 return SUCCESS;
140 }
141
Save(const StrategyMap & strategy_map,const TensorInfoMap & tensor_info_map,ManualShapeMap * manual_shape_map)142 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
143 ManualShapeMap *manual_shape_map) {
144 straspb::ParallelStrategyMap parallel_strategy_map;
145 parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(++current_stage_)));
146 for (auto &node_stra : strategy_map) {
147 straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
148 MS_EXCEPTION_IF_NULL(parallel_strategy_item);
149 parallel_strategy_item->set_node_name(node_stra.first);
150 straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
151 MS_EXCEPTION_IF_NULL(parallel_strategys);
152 MS_EXCEPTION_IF_NULL(node_stra.second);
153 parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
154 for (auto &dims : node_stra.second->GetInputDim()) {
155 straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
156 MS_EXCEPTION_IF_NULL(parallel_strategy);
157 for (auto stra_dim : dims) {
158 parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
159 }
160 }
161 }
162 for (auto &node_tensor_info : tensor_info_map) {
163 TensorLayoutPtr tensor_layout = node_tensor_info.second;
164 MS_EXCEPTION_IF_NULL(tensor_layout);
165 straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
166 MS_EXCEPTION_IF_NULL(parallel_layout_item);
167 parallel_layout_item->set_param_name(node_tensor_info.first);
168 straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
169 straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
170 MS_EXCEPTION_IF_NULL(dev_matrix);
171 for (auto dev_dim : tensor_layout->device_arrangement().array()) {
172 dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
173 }
174 straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
175 MS_EXCEPTION_IF_NULL(tensor_map);
176 for (auto map_dim : tensor_layout->tensor_map().array()) {
177 tensor_map->add_dim(LongToInt(map_dim));
178 }
179 straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
180 straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
181 MS_EXCEPTION_IF_NULL(manual_shape_map);
182 auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
183 for (auto dim_pair : manual_shape) {
184 param_split_shape->add_dim(dim_pair.first);
185 indices_offset->add_dim(dim_pair.second);
186 }
187 parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
188 parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
189 parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
190 }
191 if (!CheckPath(save_file_)) {
192 MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
193 }
194 std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
195 if (!parallel_strategy_map.SerializeToOstream(&output)) {
196 MS_LOG(ERROR) << "Save strategy file failed";
197 return FAILED;
198 }
199 output.close();
200 ChangeFileMode(save_file_, S_IRUSR | S_IWUSR);
201 return SUCCESS;
202 }
203
SaveGroupInfo(const GroupInfoMap & group_info_map)204 Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
205 straspb::ParallelGroupMap parallel_group_map;
206 for (auto &group : group_info_map) {
207 straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
208 MS_EXCEPTION_IF_NULL(parallel_group_item);
209 parallel_group_item->set_group_name(group.first);
210 straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
211 MS_EXCEPTION_IF_NULL(parallel_group_ranks);
212 for (auto &rank : group.second) {
213 parallel_group_ranks->add_dim(rank);
214 }
215 }
216 if (!CheckPath(group_info_save_file_)) {
217 MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
218 }
219 std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
220 if (!parallel_group_map.SerializeToOstream(&output)) {
221 MS_LOG(ERROR) << "Save strategy file failed";
222 return FAILED;
223 }
224 output.close();
225 ChangeFileMode(group_info_save_file_, S_IRUSR | S_IWUSR);
226 return SUCCESS;
227 }
228 } // namespace parallel
229 } // namespace mindspore
230