1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/device_manager.h"
18
19 #include <algorithm>
20 #include <string>
21 #include <vector>
22 #include <unordered_map>
23
24 #include "utils/hash_set.h"
25 #include "utils/ms_context.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore {
29 namespace parallel {
30 DeviceManagerPtr g_device_manager = nullptr;
31
CheckDeviceConfig(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)32 bool CheckDeviceConfig(int64_t device_num, int64_t global_rank, const std::string &backend,
33 const std::vector<int64_t> &stage) {
34 if (device_num <= 0) {
35 MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
36 "but got the value of device_num: "
37 << device_num;
38 return false;
39 }
40 if (global_rank < 0) {
41 MS_LOG(ERROR) << "The context configuration parameter 'global_rank' must be nonnegative, "
42 "but got the value of global_rank: "
43 << global_rank;
44 return false;
45 }
46 if (device_num > MAX_DEVICE_NUM) {
47 MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be no more than " << MAX_DEVICE_NUM
48 << ", but got the value of device_num: " << device_num;
49 return false;
50 }
51 // 'device_num_converted' must be divisible by 8
52 if (LongToSize(device_num) % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
53 MS_LOG(ERROR) << "The context configuration parameter device_num' must be divisible by 8, "
54 "or equal to 1, 2 or 4, but got the value of device_num: "
55 << device_num;
56 return false;
57 }
58 if (global_rank >= device_num) {
59 MS_LOG(ERROR) << "The context configuration parameter 'global_rank' must be less than 'device_num', "
60 "but got the value of global_rank: "
61 << global_rank << ", and the value of device_num: " << device_num;
62 return false;
63 }
64 if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
65 MS_LOG(ERROR) << "For 'InitDevice', the argument 'backend' must be hccl, nccl "
66 "or undefined_backend, but got invalid backend: "
67 << backend;
68 return false;
69 }
70 if (stage.empty()) {
71 MS_LOG(ERROR) << "The size of parameter 'stage' must be positive, but got the size of stage is empty.";
72 return false;
73 }
74 return true;
75 }
76
InitDevice(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)77 bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
78 const std::vector<int64_t> &stage) {
79 if (!CheckDeviceConfig(device_num, global_rank, backend, stage)) {
80 return false;
81 }
82
83 RankList devices;
84 RankList stage_map;
85 for (int64_t i = 0; i < device_num; ++i) {
86 devices.push_back(i);
87 }
88
89 int64_t summed_value = 0;
90 for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
91 if (*begin <= 0) {
92 MS_LOG(ERROR) << "The value in the pipeline stages should be positive value, but got the value: " << *begin;
93 return false;
94 }
95 summed_value += *begin;
96 stage_map.push_back(*begin);
97 }
98
99 if (summed_value != device_num) {
100 MS_LOG(ERROR) << "The sum of the pipeline stage must be equal to the device_num, "
101 "but got sum of the pipeline stage :"
102 << summed_value << " and the device_num : " << device_num;
103 return false;
104 }
105
106 for (auto &ele : stage_map) {
107 MS_LOG(DEBUG) << "Obtained stage id: " << ele;
108 }
109 if (g_device_manager) {
110 auto gm = g_device_manager->group_manager();
111 g_device_manager = std::make_shared<DeviceManager>();
112 g_device_manager->set_group_manager(gm);
113 } else {
114 g_device_manager = std::make_shared<DeviceManager>();
115 }
116 if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
117 MS_LOG(INFO) << "Device initialization succeeds.";
118 MS_LOG(INFO) << "g_device_manager: DeviceNum: " << g_device_manager->DeviceNum();
119 return true;
120 }
121
122 MS_LOG(ERROR) << "Device initialization fails.";
123 return false;
124 }
125
CheckGlobalDeviceManager()126 void CheckGlobalDeviceManager() {
127 if (g_device_manager == nullptr) {
128 MS_LOG(EXCEPTION) << "Device information has not been set!";
129 }
130 }
131
GetListMemberByIndex(size_t index,const RankList & devices)132 int64_t GetListMemberByIndex(size_t index, const RankList &devices) {
133 size_t i = 0;
134 int64_t result = 0;
135 if ((devices.empty()) || (index >= devices.size())) {
136 MS_LOG(EXCEPTION) << "Index is out of the list scope";
137 }
138 auto it = devices.begin();
139 for (; it != devices.end(); ++it) {
140 if (i == index) {
141 result = *it;
142 break;
143 }
144 ++i;
145 }
146 return result;
147 }
148
GetListMemberByIndex(size_t index,const std::vector<std::shared_ptr<Device>> & device_list)149 std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std::shared_ptr<Device>> &device_list) {
150 size_t i = 0;
151 std::shared_ptr<Device> result;
152 if ((device_list.empty()) || (index >= device_list.size())) {
153 MS_LOG(EXCEPTION) << "Index is out of the list scope";
154 }
155 auto it = device_list.begin();
156 for (; it != device_list.end(); ++it) {
157 if (i == index) {
158 result = *it;
159 break;
160 }
161 ++i;
162 }
163 return result;
164 }
165
166 namespace {
167 constexpr int64_t NODE_PER_SERVER = 8;
IsFeasibleDeiveListOneServer(const RankList & rank_list)168 Status IsFeasibleDeiveListOneServer(const RankList &rank_list) {
169 if (rank_list.size() == 1 || rank_list.size() == NODE_PER_SERVER) {
170 return SUCCESS;
171 }
172 if (rank_list.size() == 4 && (rank_list[3] - rank_list[0] == 3) && (rank_list[0] == 0 || rank_list[3] == 7)) {
173 return SUCCESS;
174 }
175 if (rank_list.size() == 4 && (rank_list[3] % 4 == rank_list[1] % 4) && (rank_list[2] % 4 == rank_list[0] % 4)) {
176 return SUCCESS;
177 }
178 if (rank_list.size() == 2) {
179 if (rank_list[1] - rank_list[0] == 4) {
180 return SUCCESS;
181 }
182 if (rank_list[1] < 4 && rank_list[0] < 4) {
183 return SUCCESS;
184 }
185 if (rank_list[1] >= 4 && rank_list[0] >= 4) {
186 return SUCCESS;
187 }
188 }
189 return FAILED;
190 }
191
IsFeasibleDeiveList(const RankList & rank_list)192 Status IsFeasibleDeiveList(const RankList &rank_list) {
193 std::unordered_map<int64_t, RankList> server_ranks_map;
194 for (auto rank : rank_list) {
195 int64_t server_id = rank / NODE_PER_SERVER;
196 int64_t local_rank = rank % NODE_PER_SERVER;
197 server_ranks_map[server_id].push_back(local_rank);
198 }
199 std::vector<RankList> server_ranks_list;
200 (void)std::transform(server_ranks_map.begin(), server_ranks_map.end(), std::back_inserter(server_ranks_list),
201 [](auto pairs) { return pairs.second; });
202 auto server0_local_ranks = server_ranks_list[0];
203 bool is_all_server_same_count =
204 std::all_of(server_ranks_list.begin(), server_ranks_list.end(),
205 [&server0_local_ranks](auto ranks) { return ranks == server0_local_ranks; });
206 if (!is_all_server_same_count) {
207 MS_LOG(INFO) << "All server should has the same ranks, which means rank_id % 8 in each server should be the same. "
208 "current rank list is"
209 << rank_list;
210 return FAILED;
211 }
212 return IsFeasibleDeiveListOneServer(server0_local_ranks);
213 }
214 } // namespace
215
CheckDeviceList(const RankList & rank_list) const216 Status DeviceManager::CheckDeviceList(const RankList &rank_list) const {
217 auto ms_context = MsContext::GetInstance();
218 MS_EXCEPTION_IF_NULL(ms_context);
219 auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
220 auto soc_version = ms_context->ascend_soc_version();
221 if (backend == kAscendDevice && (soc_version.empty() || soc_version == kAscendVersion910)) {
222 return IsFeasibleDeiveList(rank_list);
223 }
224 return SUCCESS;
225 }
226
227 // E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
228 // therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
Init(const RankList & devices,int64_t global_device_rank,const RankList & stage_map,const std::string & backend)229 Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
230 const std::string &backend) {
231 if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
232 MS_LOG(ERROR) << "For 'Init', the argument 'backend' must be hccl, nccl "
233 "or undefined_backend, but got invalid backend: "
234 << backend;
235 return FAILED;
236 }
237
238 if (stage_map.empty() || devices.empty()) {
239 MS_LOG(ERROR) << "The size of stage_map and devices must be positive, but got the size of stage_map: "
240 << stage_map.size() << ", and the size of devices : " << devices.size();
241 return FAILED;
242 }
243
244 devices_.clear();
245 stage_devices_.clear();
246
247 for (auto &dev : devices) {
248 std::shared_ptr<Device> one = std::make_shared<Device>(dev);
249 devices_.push_back(one);
250 }
251
252 size_t global_index = 0;
253 for (auto &stage : stage_map) {
254 int64_t num_device = stage;
255 if (num_device > MAX_DEVICE_NUM) {
256 MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM
257 << ", but got the number of 'devices' in a stage: " << num_device;
258 return FAILED;
259 }
260 if (num_device <= 0) {
261 MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive, but got the num_device: " << num_device;
262 return FAILED;
263 }
264 RankList curr_dev_list;
265 for (int64_t i = 0; i < num_device; ++i) {
266 curr_dev_list.push_back(GetListMemberByIndex(global_index, devices));
267 global_index++;
268 }
269 stage_devices_.push_back(curr_dev_list);
270 }
271
272 std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
273 device_ = dev;
274
275 global_rank_ = global_device_rank;
276 stage_num_ = static_cast<const int64_t>(stage_map.size());
277 stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
278 rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
279 stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
280
281 backend_ = backend;
282
283 if (backend == HCCL_BACKEND) {
284 gm_.set_world_group(HCCL_WORLD_GROUP);
285 } else if (backend_ == NCCL_BACKEND) {
286 gm_.set_world_group(NCCL_WORLD_GROUP);
287 } else {
288 gm_.set_world_group(UNDEFINED_WORLD_GROUP);
289 }
290 MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
291 << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
292 << ", the rank index in stage is: " << rank_index_in_stage_;
293 return SUCCESS;
294 }
295
GetDeviceListInThisStage() const296 RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
297
GetDeviceListBetweenStage() const298 RankList DeviceManager::GetDeviceListBetweenStage() const {
299 std::vector<int64_t> rank_list;
300 auto rank_id = g_device_manager->global_rank();
301 auto stage_id = g_device_manager->stage_id();
302 auto stage_num = g_device_manager->stage_num();
303 if (stage_num < 1) {
304 MS_LOG(EXCEPTION) << "Stage num got " << stage_num << ", expected a positive integer.";
305 }
306 auto device_num = DeviceNum();
307 auto per_stage_rank_num = device_num / LongToSize(stage_num);
308 for (int64_t i = 0; i < stage_num; ++i) {
309 rank_list.push_back(rank_id + SizeToLong(per_stage_rank_num) * (i - stage_id));
310 }
311 return rank_list;
312 }
313
GetDeviceListByStageId(int64_t stage_id) const314 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
315 if (LongToSize(stage_id) >= stage_devices_.size()) {
316 MS_LOG(ERROR) << "the 'stage_id': " << stage_id
317 << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
318 }
319 RankList res;
320 int64_t index = 0;
321 for (auto &stage : stage_devices_) {
322 if (index == stage_id) {
323 return stage;
324 }
325 index++;
326 }
327 return res;
328 }
329
CreateNewDeviceByRank(int64_t rank) const330 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
331
CreateDeviceListByRankList(RankList ranks) const332 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) const {
333 std::vector<Device> dev_list;
334 for (auto &rank : ranks) {
335 Device one = CreateNewDeviceByRank(rank);
336 dev_list.push_back(one);
337 }
338 return dev_list;
339 }
340
GetInstance()341 DeviceManager &DeviceManager::GetInstance() {
342 static DeviceManager instance = DeviceManager();
343 return instance;
344 }
345
FindRankListNameByHashName(const std::string & hash_name)346 std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) {
347 std::string tmp = "WORLD_GROUP";
348 if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) {
349 return tmp;
350 }
351 std::map<std::string, std::string>::const_iterator iter = group_to_rank_.find(hash_name);
352 if (iter == group_to_rank_.cend()) {
353 MS_LOG(INFO) << "Can not find the rank list name by hash name: " << hash_name;
354 return tmp;
355 }
356 return iter->second;
357 }
358
FindRankListByHashName(const std::string & hash_name)359 RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
360 std::string rank_list_name = FindRankListNameByHashName(hash_name);
361 if (rank_list_name == "WORLD_GROUP") {
362 int64_t device_num = SizeToLong(g_device_manager->DeviceNum());
363 RankList rank_list;
364 for (size_t i = 0; i < size_t(device_num); ++i) {
365 rank_list.push_back(i);
366 }
367 return rank_list;
368 }
369 RankList rank_list;
370 std::string rank_str = "";
371 rank_list_name = rank_list_name + "-";
372 for (size_t i = 0; i < rank_list_name.size(); i++) {
373 if (rank_list_name[i] == '-') {
374 int64_t rank_id = std::atoi(rank_str.c_str());
375 rank_list.push_back(rank_id);
376 rank_str = "";
377 } else if (rank_list_name[i] <= '9' && rank_list_name[i] >= '0') {
378 rank_str.push_back(rank_list_name[i]);
379 } else {
380 MS_LOG(EXCEPTION) << "The rank list name cannot convert to rank list: " << rank_list_name;
381 }
382 }
383 return rank_list;
384 }
385
// Returns the decimal string of std::hash<std::string> applied to `origin_name`.
// Uses the fully qualified std::string rather than relying on a `string` using-declaration from a header.
std::string HashName(const std::string &origin_name) {
  const size_t hash_value = std::hash<std::string>{}(origin_name);
  return std::to_string(hash_value);
}
387
RankListName(const RankList & ranks)388 std::string RankListName(const RankList &ranks) {
389 std::string rank_list_name;
390 for (auto it = ranks.begin(); it != ranks.end(); ++it) {
391 if (it == ranks.begin()) {
392 rank_list_name = std::to_string(*it);
393 } else {
394 rank_list_name += "-" + std::to_string(*it);
395 }
396 }
397 return rank_list_name;
398 }
399
400 // Group name is generated using the increasing ranks of the devices.
401 // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
402 // is '0-1-3-5-7'.
GenerateGroupNameByRanks(RankList ranks)403 std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
404 std::sort(ranks.begin(), ranks.end()); // sorted in increasing order
405 std::string rank_list_name = RankListName(ranks);
406
407 // hash rank-list-name and add ranks' size as prefix
408 std::string group_hash_name = HashName(rank_list_name);
409 std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;
410
411 if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) {
412 if (group_to_rank_.find(group_name) == group_to_rank_.end()) {
413 rank_to_group_[rank_list_name] = group_name;
414 group_to_rank_[group_name] = rank_list_name;
415 MS_LOG(INFO) << "The rank list name is " << rank_list_name << " and group name is " << group_name;
416 } else {
417 MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name
418 << "the old rank list:" << group_to_rank_.find(group_name)->second
419 << "the group name: " << group_name;
420 }
421 }
422 return group_name;
423 }
424
425 // Create the group with the given devices and the given name. The GroupManager
426 // gm_ will create a new group only if there does not exit a group with the same
427 // name. Otherwise, let the pointer g point to that group.
CreateGroup(const std::string & group_name,const std::vector<mindspore::parallel::Device> & devices,Group * const comm_group)428 Status DeviceManager::CreateGroup(const std::string &group_name,
429 const std::vector<mindspore::parallel::Device> &devices, Group *const comm_group) {
430 RankList rank_list;
431 (void)std::transform(devices.begin(), devices.end(), std::back_inserter(rank_list),
432 [](const Device &device) { return device.rank(); });
433 if (CheckDeviceList(rank_list) != SUCCESS) {
434 MS_LOG(ERROR) << "Create communication group failed, the rank list is: " << rank_list;
435 return FAILED;
436 }
437 if (gm_.CreateGroup(group_name, devices, comm_group) != SUCCESS) {
438 return FAILED;
439 }
440 group_to_rank_[group_name] = RankListName(rank_list);
441 return SUCCESS;
442 }
443
444 // Create the group with only the given devices' ranks.
CreateGroup(const RankList & dev_ranks,Group * const comm_group)445 Status DeviceManager::CreateGroup(const RankList &dev_ranks, Group *const comm_group) {
446 mindspore::HashSet<int64_t> rank_set(dev_ranks.begin(), dev_ranks.end());
447 if (dev_ranks.size() != rank_set.size()) {
448 MS_LOG(ERROR) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list";
449 return FAILED;
450 }
451 if (CheckDeviceList(dev_ranks) != SUCCESS) {
452 MS_LOG(ERROR) << "Create communication group failed, the rank list is: " << dev_ranks;
453 return FAILED;
454 }
455 std::string group_name = GenerateGroupNameByRanks(dev_ranks);
456 auto dev_list = CreateDeviceListByRankList(dev_ranks);
457 return CreateGroup(group_name, dev_list, comm_group);
458 }
459
Clear()460 void DeviceManager::Clear() {
461 devices_.clear();
462 stage_devices_.clear();
463 gm_.Clear();
464 }
465 } // namespace parallel
466 } // namespace mindspore
467