1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/device_manager.h"
18
19 #include <algorithm>
20 #include <string>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "frontend/parallel/step_parallel.h"
25 #include "utils/log_adapter.h"
26
27 namespace mindspore {
28 namespace parallel {
// Global device manager. It is null until InitDevice() assigns it; CheckGlobalDeviceManager() raises when it is
// still null.
DeviceManagerPtr g_device_manager = nullptr;
InitDevice(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)30 bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
31 const std::vector<int64_t> &stage) {
32 if (device_num <= 0) {
33 MS_LOG(ERROR) << "'device_num' must be positive.";
34 return false;
35 }
36 if (global_rank < 0) {
37 MS_LOG(ERROR) << "'global_rank' must be nonnegative.";
38 return false;
39 }
40 if (device_num > MAX_DEVICE_NUM) {
41 MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << ".";
42 return false;
43 }
44 // 'device_num_converted' must be the power of 2
45 if ((LongToUlong(device_num) & LongToUlong(device_num - 1)) != 0) {
46 MS_LOG(ERROR) << "'device_num' must be the power of 2.";
47 return false;
48 }
49 if (global_rank >= device_num) {
50 MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'.";
51 return false;
52 }
53 if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
54 MS_LOG(ERROR) << "Invalid backend: " << backend;
55 return false;
56 }
57 if (stage.empty()) {
58 MS_LOG(ERROR) << "The size of stage must be positive";
59 return false;
60 }
61
62 RankList devices, stage_map;
63 for (int64_t i = 0; i < device_num; ++i) {
64 devices.push_back(i);
65 }
66
67 int64_t summed_value = 0;
68 for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
69 if (*begin <= 0) {
70 MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
71 return false;
72 }
73 summed_value += *begin;
74 stage_map.push_back(*begin);
75 }
76
77 if (summed_value != device_num) {
78 MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
79 << device_num;
80 return false;
81 }
82
83 for (auto &ele : stage_map) {
84 MS_LOG(DEBUG) << "Obtained stage id: " << ele;
85 }
86 if (g_device_manager) {
87 auto gm = g_device_manager->group_manager();
88 g_device_manager = std::make_shared<DeviceManager>();
89 g_device_manager->set_group_manager(gm);
90 } else {
91 g_device_manager = std::make_shared<DeviceManager>();
92 }
93 if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
94 MS_LOG(INFO) << "Device initialization succeeds.";
95 return true;
96 }
97
98 MS_LOG(ERROR) << "Device initialization fails.";
99 return false;
100 }
101
CheckGlobalDeviceManager()102 void CheckGlobalDeviceManager() {
103 if (g_device_manager == nullptr) {
104 MS_LOG(EXCEPTION) << "Device information has not been set!";
105 }
106 }
107
GetListMemberByIndex(size_t index,const RankList & devices)108 int64_t GetListMemberByIndex(size_t index, const RankList &devices) {
109 size_t i = 0;
110 int64_t result = 0;
111 if ((devices.empty()) || (index >= devices.size())) {
112 MS_LOG(EXCEPTION) << "Index is out of the list scope";
113 }
114 auto it = devices.begin();
115 for (; it != devices.end(); ++it) {
116 if (i == index) {
117 result = *it;
118 break;
119 }
120 ++i;
121 }
122 return result;
123 }
124
GetListMemberByIndex(size_t index,const std::vector<std::shared_ptr<Device>> & device_list)125 std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std::shared_ptr<Device>> &device_list) {
126 size_t i = 0;
127 std::shared_ptr<Device> result;
128 if ((device_list.empty()) || (index >= device_list.size())) {
129 MS_LOG(EXCEPTION) << "Index is out of the list scope";
130 }
131 auto it = device_list.begin();
132 for (; it != device_list.end(); ++it) {
133 if (i == index) {
134 result = *it;
135 break;
136 }
137 ++i;
138 }
139 return result;
140 }
141
142 // E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
143 // therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
Init(const RankList & devices,int64_t global_device_rank,const RankList & stage_map,const std::string & backend)144 Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
145 const std::string &backend) {
146 if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
147 MS_LOG(ERROR) << "Invalid backend: " << backend;
148 return FAILED;
149 }
150
151 if (stage_map.empty() || devices.empty()) {
152 MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
153 return FAILED;
154 }
155
156 for (auto &dev : devices) {
157 std::shared_ptr<Device> one = std::make_shared<Device>(dev);
158 devices_.push_back(one);
159 }
160
161 size_t global_index = 0;
162 for (auto &stage : stage_map) {
163 int64_t num_device = stage;
164 if (num_device > MAX_DEVICE_NUM) {
165 MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
166 return FAILED;
167 }
168 if (num_device <= 0) {
169 MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
170 return FAILED;
171 }
172 RankList curr_dev_list;
173 for (int64_t i = 0; i < num_device; ++i) {
174 curr_dev_list.push_back(GetListMemberByIndex(global_index, devices));
175 global_index++;
176 }
177 stage_devices_.push_back(curr_dev_list);
178 }
179
180 std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
181 device_ = dev;
182
183 global_rank_ = global_device_rank;
184 stage_num_ = static_cast<const int64_t>(stage_map.size());
185 stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
186 rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
187 stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
188
189 backend_ = backend;
190
191 if (backend == HCCL_BACKEND) {
192 gm_.set_world_group(HCCL_WORLD_GROUP);
193 } else if (backend_ == NCCL_BACKEND) {
194 gm_.set_world_group(NCCL_WORLD_GROUP);
195 } else {
196 gm_.set_world_group(UNDEFINED_WORLD_GROUP);
197 }
198 MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
199 << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
200 << ", the rank index in stage is: " << rank_index_in_stage_;
201 return SUCCESS;
202 }
203
GetDeviceListInThisStage() const204 RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
205
GetDeviceListByStageId(int64_t stage_id) const206 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
207 if (LongToSize(stage_id) >= stage_devices_.size())
208 MS_LOG(ERROR) << "the 'stage_id': " << stage_id
209 << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
210 RankList res;
211 int64_t index = 0;
212 for (auto &stage : stage_devices_) {
213 if (index == stage_id) {
214 return stage;
215 }
216 index++;
217 }
218 return res;
219 }
220
CreateNewDeviceByRank(int64_t rank) const221 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
222
CreateDeviceListByRankList(RankList ranks)223 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
224 std::vector<Device> dev_list;
225 for (auto &rank : ranks) {
226 Device one = CreateNewDeviceByRank(rank);
227 dev_list.push_back(one);
228 }
229 return dev_list;
230 }
231
GetInstance()232 DeviceManager &DeviceManager::GetInstance() {
233 static DeviceManager instance = DeviceManager();
234 return instance;
235 }
236
FindRankListNameByHashName(const std::string & hash_name)237 std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) {
238 std::string tmp = "WORLD_GROUP";
239 if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) {
240 return tmp;
241 }
242 auto iter = group_to_rank_.find(hash_name);
243 if (iter == group_to_rank_.end()) {
244 MS_LOG(WARNING) << "Can not find the rank list name by hash name: " << hash_name;
245 return tmp;
246 }
247 return iter->second;
248 }
249
// Returns the decimal text of std::hash<std::string> applied to 'origin_name'.
// The hash type is spelled with the fully qualified std::string so that it does not depend on a using-declaration
// for 'string' being visible in this namespace.
std::string HashName(const std::string &origin_name) {
  return std::to_string(std::hash<std::string>{}(origin_name));
}
251
252 // Group name is generated using the increasing ranks of the devices.
253 // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
254 // is '0-1-3-5-7'.
GenerateGroupNameByRanks(RankList ranks)255 std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
256 std::string rank_list_name;
257 std::vector<int64_t>::iterator it;
258 std::sort(ranks.begin(), ranks.end()); // sorted in increasing order
259 for (it = ranks.begin(); it != ranks.end(); ++it) {
260 if (it == ranks.begin()) {
261 rank_list_name = std::to_string(*it);
262 } else {
263 rank_list_name += "-" + std::to_string(*it);
264 }
265 }
266
267 // hash rank-list-name and add ranks' size as prefix
268 std::string group_hash_name = HashName(rank_list_name);
269 std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;
270
271 if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) {
272 if (group_to_rank_.find(group_name) == group_to_rank_.end()) {
273 rank_to_group_[rank_list_name] = group_name;
274 group_to_rank_[group_name] = rank_list_name;
275 MS_LOG(INFO) << "The rank list name is " << rank_list_name << "nd group name is " << group_name;
276 } else {
277 MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name
278 << "the old rank list:" << group_to_rank_.find(group_name)->second
279 << "the group name: " << group_name;
280 }
281 }
282 return group_name;
283 }
284
285 // Create the group with the given devices and the given name. The GroupManager
286 // gm_ will create a new group only if there does not exit a group with the same
287 // name. Otherwise, let the pointer g point to that group.
CreateGroup(const std::string & group_name,const std::vector<mindspore::parallel::Device> & devices)288 Group DeviceManager::CreateGroup(const std::string &group_name,
289 const std::vector<mindspore::parallel::Device> &devices) {
290 Group g;
291 (void)gm_.CreateGroup(group_name, devices, &g);
292 return g;
293 }
294
295 // Create the group with only the given devices' ranks.
CreateGroup(const RankList & dev_ranks)296 Group DeviceManager::CreateGroup(const RankList &dev_ranks) {
297 std::unordered_set<int64_t> rank_set(dev_ranks.begin(), dev_ranks.end());
298 if (dev_ranks.size() != rank_set.size()) {
299 MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list";
300 }
301
302 std::string group_name = GenerateGroupNameByRanks(dev_ranks);
303 auto dev_list = CreateDeviceListByRankList(dev_ranks);
304 return CreateGroup(group_name, dev_list);
305 }
306
// Empties the device list and the per-stage rank lists, then asks the group manager to clear itself.
// The rank list name <-> group name maps (rank_to_group_ / group_to_rank_) are not touched here.
void DeviceManager::Clear() {
  devices_.clear();
  stage_devices_.clear();
  gm_.Clear();
}
312 } // namespace parallel
313 } // namespace mindspore
314