• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
18 
19 #include <fstream>
20 #include <vector>
21 
22 #include "utils/ms_utils.h"
23 #include "utils/convert_utils.h"
24 #include "utils/log_adapter.h"
25 #include "debug/common.h"
26 #include "proto/node_strategy.pb.h"
27 
28 namespace mindspore {
29 namespace parallel {
GetInstance()30 StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
31   static StrategyCheckpoint instance = StrategyCheckpoint();
32   if (ParallelContext::GetInstance() != nullptr) {
33     instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
34     instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
35     instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
36     instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
37     instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
38     instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
39   }
40   return instance;
41 }
42 
CheckPath(const std::string path) const43 bool StrategyCheckpoint::CheckPath(const std::string path) const {
44   if (path.size() > PATH_MAX) {
45     MS_LOG(ERROR) << "The checkpoit path " << path << " is too long";
46     return false;
47   }
48   auto realpath = Common::CreatePrefixPath(path);
49   if (!realpath.has_value()) {
50     MS_LOG(ERROR) << "Get real path failed, path=" << path;
51     return false;
52   }
53   return true;
54 }
55 
CheckPointExit(const std::string path) const56 bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
57   std::ifstream fin(path);
58   if (fin) {
59     return true;
60   }
61   return false;
62 }
63 
LoadGroupInfo(const std::string & file,GroupInfoMap * group_info_map)64 Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
65   MS_EXCEPTION_IF_NULL(group_info_map);
66   if (!CheckPath(file)) {
67     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
68   }
69   if (!CheckPointExit(file)) {
70     MS_LOG(EXCEPTION) << "CheckPoint file is not found";
71   }
72   straspb::ParallelGroupMap parallel_group_map;
73   std::fstream input(file, std::ios::in | std::ios::binary);
74   if (!parallel_group_map.ParseFromIstream(&input)) {
75     MS_LOG(ERROR) << "Load strategy file failed";
76     return FAILED;
77   }
78   input.close();
79 
80   size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
81   for (size_t i = 0; i < group_num; ++i) {
82     straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToInt(i));
83     std::string group_name = parallel_group_item.group_name();
84 
85     straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
86     size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
87     std::vector<uint32_t> ranks;
88     for (size_t j = 0; j < rank_num; ++j) {
89       uint32_t rank = parallel_group_ranks.dim(SizeToInt(j));
90       ranks.push_back(rank);
91     }
92 
93     std::pair<std::string, std::vector<uint32_t>> group = std::make_pair(group_name, ranks);
94     group_info_map->push_back(group);
95   }
96 
97   return SUCCESS;
98 }
99 
Load(StrategyMap * strategy_map)100 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
101   if (strategy_map == nullptr) {
102     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
103   }
104   if (!CheckPath(load_file_)) {
105     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
106   }
107   if (!CheckPointExit(load_file_)) {
108     MS_LOG(EXCEPTION) << "CheckPoint file is not found";
109   }
110   straspb::ParallelStrategyMap parallel_strategy_map;
111   std::fstream input(load_file_, std::ios::in | std::ios::binary);
112   if (!parallel_strategy_map.ParseFromIstream(&input)) {
113     MS_LOG(ERROR) << "Load strategy file failed";
114     return FAILED;
115   }
116   input.close();
117   size_t node_num = LongToSize(parallel_strategy_map.parallel_strategy_item_size());
118   for (size_t i = 0; i < node_num; i++) {
119     straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
120     std::string node_name = parallel_strategy_item.node_name();
121     straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
122     auto stage = (int64_t)parallel_strategys.stage();
123     size_t strategys_num = LongToSize(parallel_strategys.parallel_strategy_size());
124     Strategys strategy_inputs;
125     for (size_t j = 0; j < strategys_num; j++) {
126       straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
127       Dimensions dimension;
128       size_t dim_num = LongToSize(parallel_strategy.dim_size());
129       for (size_t k = 0; k < dim_num; k++) {
130         dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
131       }
132       strategy_inputs.push_back(dimension);
133     }
134 
135     StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
136     (*strategy_map)[node_name] = strategy;
137     current_stage_ = (int64_t)parallel_strategy_map.current_stage();
138   }
139   return SUCCESS;
140 }
141 
Save(const StrategyMap & strategy_map,const TensorInfoMap & tensor_info_map,ManualShapeMap * manual_shape_map)142 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
143                                 ManualShapeMap *manual_shape_map) {
144   straspb::ParallelStrategyMap parallel_strategy_map;
145   parallel_strategy_map.set_current_stage(UlongToUint(LongToUlong(++current_stage_)));
146   for (auto &node_stra : strategy_map) {
147     straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
148     MS_EXCEPTION_IF_NULL(parallel_strategy_item);
149     parallel_strategy_item->set_node_name(node_stra.first);
150     straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
151     MS_EXCEPTION_IF_NULL(parallel_strategys);
152     MS_EXCEPTION_IF_NULL(node_stra.second);
153     parallel_strategys->set_stage(UlongToUint(LongToUlong(node_stra.second->GetInputStage())));
154     for (auto &dims : node_stra.second->GetInputDim()) {
155       straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
156       MS_EXCEPTION_IF_NULL(parallel_strategy);
157       for (auto stra_dim : dims) {
158         parallel_strategy->add_dim(UlongToUint(LongToUlong(stra_dim)));
159       }
160     }
161   }
162   for (auto &node_tensor_info : tensor_info_map) {
163     TensorLayoutPtr tensor_layout = node_tensor_info.second;
164     MS_EXCEPTION_IF_NULL(tensor_layout);
165     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
166     MS_EXCEPTION_IF_NULL(parallel_layout_item);
167     parallel_layout_item->set_param_name(node_tensor_info.first);
168     straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
169     straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
170     MS_EXCEPTION_IF_NULL(dev_matrix);
171     for (auto dev_dim : tensor_layout->device_arrangement().array()) {
172       dev_matrix->add_dim(UlongToUint(LongToUlong(dev_dim)));
173     }
174     straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
175     MS_EXCEPTION_IF_NULL(tensor_map);
176     for (auto map_dim : tensor_layout->tensor_map().array()) {
177       tensor_map->add_dim(LongToInt(map_dim));
178     }
179     straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
180     straspb::IndicesOffset *indices_offset = parallel_layouts->add_indices_offset();
181     MS_EXCEPTION_IF_NULL(manual_shape_map);
182     auto manual_shape = (*manual_shape_map)[node_tensor_info.first];
183     for (auto dim_pair : manual_shape) {
184       param_split_shape->add_dim(dim_pair.first);
185       indices_offset->add_dim(dim_pair.second);
186     }
187     parallel_layouts->set_field(LongToInt(tensor_layout->get_field_size()));
188     parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
189     parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
190   }
191   if (!CheckPath(save_file_)) {
192     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
193   }
194   std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
195   if (!parallel_strategy_map.SerializeToOstream(&output)) {
196     MS_LOG(ERROR) << "Save strategy file failed";
197     return FAILED;
198   }
199   output.close();
200   ChangeFileMode(save_file_, S_IRUSR | S_IWUSR);
201   return SUCCESS;
202 }
203 
SaveGroupInfo(const GroupInfoMap & group_info_map)204 Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
205   straspb::ParallelGroupMap parallel_group_map;
206   for (auto &group : group_info_map) {
207     straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
208     MS_EXCEPTION_IF_NULL(parallel_group_item);
209     parallel_group_item->set_group_name(group.first);
210     straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
211     MS_EXCEPTION_IF_NULL(parallel_group_ranks);
212     for (auto &rank : group.second) {
213       parallel_group_ranks->add_dim(rank);
214     }
215   }
216   if (!CheckPath(group_info_save_file_)) {
217     MS_LOG(EXCEPTION) << "CheckPoint file in invalid";
218   }
219   std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
220   if (!parallel_group_map.SerializeToOstream(&output)) {
221     MS_LOG(ERROR) << "Save strategy file failed";
222     return FAILED;
223   }
224   output.close();
225   ChangeFileMode(group_info_save_file_, S_IRUSR | S_IWUSR);
226   return SUCCESS;
227 }
228 }  // namespace parallel
229 }  // namespace mindspore
230