• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/device_manager.h"
18 
19 #include <algorithm>
20 #include <string>
21 #include <vector>
22 #include <unordered_map>
23 
24 #include "utils/hash_set.h"
25 #include "utils/ms_context.h"
26 #include "utils/log_adapter.h"
27 
28 namespace mindspore {
29 namespace parallel {
30 DeviceManagerPtr g_device_manager = nullptr;
31 
CheckDeviceConfig(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)32 bool CheckDeviceConfig(int64_t device_num, int64_t global_rank, const std::string &backend,
33                        const std::vector<int64_t> &stage) {
34   if (device_num <= 0) {
35     MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
36                      "but got the value of device_num: "
37                   << device_num;
38     return false;
39   }
40   if (global_rank < 0) {
41     MS_LOG(ERROR) << "The context configuration parameter 'global_rank' must be nonnegative, "
42                      "but got the value of global_rank: "
43                   << global_rank;
44     return false;
45   }
46   if (device_num > MAX_DEVICE_NUM) {
47     MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be no more than " << MAX_DEVICE_NUM
48                   << ", but got the value of device_num: " << device_num;
49     return false;
50   }
51   // 'device_num_converted' must be divisible by 8
52   if (LongToSize(device_num) % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
53     MS_LOG(ERROR) << "The context configuration parameter device_num' must be divisible by 8, "
54                      "or equal to 1, 2 or 4, but got the value of device_num: "
55                   << device_num;
56     return false;
57   }
58   if (global_rank >= device_num) {
59     MS_LOG(ERROR) << "The context configuration parameter 'global_rank' must be less than 'device_num', "
60                      "but got the value of global_rank: "
61                   << global_rank << ", and the value of device_num: " << device_num;
62     return false;
63   }
64   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
65     MS_LOG(ERROR) << "For 'InitDevice', the argument 'backend' must be hccl, nccl "
66                      "or undefined_backend, but got invalid backend: "
67                   << backend;
68     return false;
69   }
70   if (stage.empty()) {
71     MS_LOG(ERROR) << "The size of parameter 'stage' must be positive, but got the size of stage is empty.";
72     return false;
73   }
74   return true;
75 }
76 
InitDevice(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)77 bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
78                 const std::vector<int64_t> &stage) {
79   if (!CheckDeviceConfig(device_num, global_rank, backend, stage)) {
80     return false;
81   }
82 
83   RankList devices;
84   RankList stage_map;
85   for (int64_t i = 0; i < device_num; ++i) {
86     devices.push_back(i);
87   }
88 
89   int64_t summed_value = 0;
90   for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
91     if (*begin <= 0) {
92       MS_LOG(ERROR) << "The value in the pipeline stages should be positive value, but got the value: " << *begin;
93       return false;
94     }
95     summed_value += *begin;
96     stage_map.push_back(*begin);
97   }
98 
99   if (summed_value != device_num) {
100     MS_LOG(ERROR) << "The sum of the pipeline stage must be equal to the device_num, "
101                      "but got sum of the pipeline stage :"
102                   << summed_value << " and the device_num : " << device_num;
103     return false;
104   }
105 
106   for (auto &ele : stage_map) {
107     MS_LOG(DEBUG) << "Obtained stage id: " << ele;
108   }
109   if (g_device_manager) {
110     auto gm = g_device_manager->group_manager();
111     g_device_manager = std::make_shared<DeviceManager>();
112     g_device_manager->set_group_manager(gm);
113   } else {
114     g_device_manager = std::make_shared<DeviceManager>();
115   }
116   if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
117     MS_LOG(INFO) << "Device initialization succeeds.";
118     MS_LOG(INFO) << "g_device_manager: DeviceNum: " << g_device_manager->DeviceNum();
119     return true;
120   }
121 
122   MS_LOG(ERROR) << "Device initialization fails.";
123   return false;
124 }
125 
CheckGlobalDeviceManager()126 void CheckGlobalDeviceManager() {
127   if (g_device_manager == nullptr) {
128     MS_LOG(EXCEPTION) << "Device information has not been set!";
129   }
130 }
131 
GetListMemberByIndex(size_t index,const RankList & devices)132 int64_t GetListMemberByIndex(size_t index, const RankList &devices) {
133   size_t i = 0;
134   int64_t result = 0;
135   if ((devices.empty()) || (index >= devices.size())) {
136     MS_LOG(EXCEPTION) << "Index is out of the list scope";
137   }
138   auto it = devices.begin();
139   for (; it != devices.end(); ++it) {
140     if (i == index) {
141       result = *it;
142       break;
143     }
144     ++i;
145   }
146   return result;
147 }
148 
GetListMemberByIndex(size_t index,const std::vector<std::shared_ptr<Device>> & device_list)149 std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std::shared_ptr<Device>> &device_list) {
150   size_t i = 0;
151   std::shared_ptr<Device> result;
152   if ((device_list.empty()) || (index >= device_list.size())) {
153     MS_LOG(EXCEPTION) << "Index is out of the list scope";
154   }
155   auto it = device_list.begin();
156   for (; it != device_list.end(); ++it) {
157     if (i == index) {
158       result = *it;
159       break;
160     }
161     ++i;
162   }
163   return result;
164 }
165 
166 namespace {
167 constexpr int64_t NODE_PER_SERVER = 8;
IsFeasibleDeiveListOneServer(const RankList & rank_list)168 Status IsFeasibleDeiveListOneServer(const RankList &rank_list) {
169   if (rank_list.size() == 1 || rank_list.size() == NODE_PER_SERVER) {
170     return SUCCESS;
171   }
172   if (rank_list.size() == 4 && (rank_list[3] - rank_list[0] == 3) && (rank_list[0] == 0 || rank_list[3] == 7)) {
173     return SUCCESS;
174   }
175   if (rank_list.size() == 4 && (rank_list[3] % 4 == rank_list[1] % 4) && (rank_list[2] % 4 == rank_list[0] % 4)) {
176     return SUCCESS;
177   }
178   if (rank_list.size() == 2) {
179     if (rank_list[1] - rank_list[0] == 4) {
180       return SUCCESS;
181     }
182     if (rank_list[1] < 4 && rank_list[0] < 4) {
183       return SUCCESS;
184     }
185     if (rank_list[1] >= 4 && rank_list[0] >= 4) {
186       return SUCCESS;
187     }
188   }
189   return FAILED;
190 }
191 
IsFeasibleDeiveList(const RankList & rank_list)192 Status IsFeasibleDeiveList(const RankList &rank_list) {
193   std::unordered_map<int64_t, RankList> server_ranks_map;
194   for (auto rank : rank_list) {
195     int64_t server_id = rank / NODE_PER_SERVER;
196     int64_t local_rank = rank % NODE_PER_SERVER;
197     server_ranks_map[server_id].push_back(local_rank);
198   }
199   std::vector<RankList> server_ranks_list;
200   (void)std::transform(server_ranks_map.begin(), server_ranks_map.end(), std::back_inserter(server_ranks_list),
201                        [](auto pairs) { return pairs.second; });
202   auto server0_local_ranks = server_ranks_list[0];
203   bool is_all_server_same_count =
204     std::all_of(server_ranks_list.begin(), server_ranks_list.end(),
205                 [&server0_local_ranks](auto ranks) { return ranks == server0_local_ranks; });
206   if (!is_all_server_same_count) {
207     MS_LOG(INFO) << "All server should has the same ranks, which means rank_id % 8 in each server should be the same. "
208                     "current rank list is"
209                  << rank_list;
210     return FAILED;
211   }
212   return IsFeasibleDeiveListOneServer(server0_local_ranks);
213 }
214 }  // namespace
215 
CheckDeviceList(const RankList & rank_list) const216 Status DeviceManager::CheckDeviceList(const RankList &rank_list) const {
217   auto ms_context = MsContext::GetInstance();
218   MS_EXCEPTION_IF_NULL(ms_context);
219   auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
220   auto soc_version = ms_context->ascend_soc_version();
221   if (backend == kAscendDevice && (soc_version.empty() || soc_version == kAscendVersion910)) {
222     return IsFeasibleDeiveList(rank_list);
223   }
224   return SUCCESS;
225 }
226 
227 // E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
228 // therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
Init(const RankList & devices,int64_t global_device_rank,const RankList & stage_map,const std::string & backend)229 Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
230                            const std::string &backend) {
231   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
232     MS_LOG(ERROR) << "For 'Init', the argument 'backend' must be hccl, nccl "
233                      "or undefined_backend, but got invalid backend: "
234                   << backend;
235     return FAILED;
236   }
237 
238   if (stage_map.empty() || devices.empty()) {
239     MS_LOG(ERROR) << "The size of stage_map and devices must be positive, but got the size of stage_map: "
240                   << stage_map.size() << ", and the size of devices : " << devices.size();
241     return FAILED;
242   }
243 
244   devices_.clear();
245   stage_devices_.clear();
246 
247   for (auto &dev : devices) {
248     std::shared_ptr<Device> one = std::make_shared<Device>(dev);
249     devices_.push_back(one);
250   }
251 
252   size_t global_index = 0;
253   for (auto &stage : stage_map) {
254     int64_t num_device = stage;
255     if (num_device > MAX_DEVICE_NUM) {
256       MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM
257                     << ", but got the number of 'devices' in a stage: " << num_device;
258       return FAILED;
259     }
260     if (num_device <= 0) {
261       MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive, but got the num_device: " << num_device;
262       return FAILED;
263     }
264     RankList curr_dev_list;
265     for (int64_t i = 0; i < num_device; ++i) {
266       curr_dev_list.push_back(GetListMemberByIndex(global_index, devices));
267       global_index++;
268     }
269     stage_devices_.push_back(curr_dev_list);
270   }
271 
272   std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
273   device_ = dev;
274 
275   global_rank_ = global_device_rank;
276   stage_num_ = static_cast<const int64_t>(stage_map.size());
277   stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
278   rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
279   stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
280 
281   backend_ = backend;
282 
283   if (backend == HCCL_BACKEND) {
284     gm_.set_world_group(HCCL_WORLD_GROUP);
285   } else if (backend_ == NCCL_BACKEND) {
286     gm_.set_world_group(NCCL_WORLD_GROUP);
287   } else {
288     gm_.set_world_group(UNDEFINED_WORLD_GROUP);
289   }
290   MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
291                << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
292                << ", the rank index in stage is: " << rank_index_in_stage_;
293   return SUCCESS;
294 }
295 
GetDeviceListInThisStage() const296 RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
297 
GetDeviceListBetweenStage() const298 RankList DeviceManager::GetDeviceListBetweenStage() const {
299   std::vector<int64_t> rank_list;
300   auto rank_id = g_device_manager->global_rank();
301   auto stage_id = g_device_manager->stage_id();
302   auto stage_num = g_device_manager->stage_num();
303   if (stage_num < 1) {
304     MS_LOG(EXCEPTION) << "Stage num got " << stage_num << ", expected a positive integer.";
305   }
306   auto device_num = DeviceNum();
307   auto per_stage_rank_num = device_num / LongToSize(stage_num);
308   for (int64_t i = 0; i < stage_num; ++i) {
309     rank_list.push_back(rank_id + SizeToLong(per_stage_rank_num) * (i - stage_id));
310   }
311   return rank_list;
312 }
313 
GetDeviceListByStageId(int64_t stage_id) const314 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
315   if (LongToSize(stage_id) >= stage_devices_.size()) {
316     MS_LOG(ERROR) << "the 'stage_id': " << stage_id
317                   << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
318   }
319   RankList res;
320   int64_t index = 0;
321   for (auto &stage : stage_devices_) {
322     if (index == stage_id) {
323       return stage;
324     }
325     index++;
326   }
327   return res;
328 }
329 
CreateNewDeviceByRank(int64_t rank) const330 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
331 
CreateDeviceListByRankList(RankList ranks) const332 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) const {
333   std::vector<Device> dev_list;
334   for (auto &rank : ranks) {
335     Device one = CreateNewDeviceByRank(rank);
336     dev_list.push_back(one);
337   }
338   return dev_list;
339 }
340 
GetInstance()341 DeviceManager &DeviceManager::GetInstance() {
342   static DeviceManager instance = DeviceManager();
343   return instance;
344 }
345 
FindRankListNameByHashName(const std::string & hash_name)346 std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) {
347   std::string tmp = "WORLD_GROUP";
348   if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) {
349     return tmp;
350   }
351   std::map<std::string, std::string>::const_iterator iter = group_to_rank_.find(hash_name);
352   if (iter == group_to_rank_.cend()) {
353     MS_LOG(INFO) << "Can not find the rank list name by hash name: " << hash_name;
354     return tmp;
355   }
356   return iter->second;
357 }
358 
FindRankListByHashName(const std::string & hash_name)359 RankList DeviceManager::FindRankListByHashName(const std::string &hash_name) {
360   std::string rank_list_name = FindRankListNameByHashName(hash_name);
361   if (rank_list_name == "WORLD_GROUP") {
362     int64_t device_num = SizeToLong(g_device_manager->DeviceNum());
363     RankList rank_list;
364     for (size_t i = 0; i < size_t(device_num); ++i) {
365       rank_list.push_back(i);
366     }
367     return rank_list;
368   }
369   RankList rank_list;
370   std::string rank_str = "";
371   rank_list_name = rank_list_name + "-";
372   for (size_t i = 0; i < rank_list_name.size(); i++) {
373     if (rank_list_name[i] == '-') {
374       int64_t rank_id = std::atoi(rank_str.c_str());
375       rank_list.push_back(rank_id);
376       rank_str = "";
377     } else if (rank_list_name[i] <= '9' && rank_list_name[i] >= '0') {
378       rank_str.push_back(rank_list_name[i]);
379     } else {
380       MS_LOG(EXCEPTION) << "The rank list name cannot convert to rank list: " << rank_list_name;
381     }
382   }
383   return rank_list;
384 }
385 
// Returns the decimal string form of std::hash<std::string> applied to `origin_name`.
// GenerateGroupNameByRanks uses it to derive compact group names from rank-list names.
// The value is specific to the standard-library implementation, so it is not stable across toolchains.
// `std::string` is spelled out instead of relying on an unqualified `string` from elsewhere.
std::string HashName(const std::string &origin_name) { return std::to_string(std::hash<std::string>{}(origin_name)); }
387 
RankListName(const RankList & ranks)388 std::string RankListName(const RankList &ranks) {
389   std::string rank_list_name;
390   for (auto it = ranks.begin(); it != ranks.end(); ++it) {
391     if (it == ranks.begin()) {
392       rank_list_name = std::to_string(*it);
393     } else {
394       rank_list_name += "-" + std::to_string(*it);
395     }
396   }
397   return rank_list_name;
398 }
399 
400 // Group name is generated using the increasing ranks of the devices.
401 // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
402 // is '0-1-3-5-7'.
GenerateGroupNameByRanks(RankList ranks)403 std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
404   std::sort(ranks.begin(), ranks.end());  // sorted in increasing order
405   std::string rank_list_name = RankListName(ranks);
406 
407   // hash rank-list-name and add ranks' size as prefix
408   std::string group_hash_name = HashName(rank_list_name);
409   std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;
410 
411   if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) {
412     if (group_to_rank_.find(group_name) == group_to_rank_.end()) {
413       rank_to_group_[rank_list_name] = group_name;
414       group_to_rank_[group_name] = rank_list_name;
415       MS_LOG(INFO) << "The rank list name is " << rank_list_name << " and group name is " << group_name;
416     } else {
417       MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name
418                         << "the old rank list:" << group_to_rank_.find(group_name)->second
419                         << "the group name: " << group_name;
420     }
421   }
422   return group_name;
423 }
424 
425 // Create the group with the given devices and the given name. The GroupManager
426 // gm_ will create a new group only if there does not exit a group with the same
427 // name. Otherwise, let the pointer g point to that group.
CreateGroup(const std::string & group_name,const std::vector<mindspore::parallel::Device> & devices,Group * const comm_group)428 Status DeviceManager::CreateGroup(const std::string &group_name,
429                                   const std::vector<mindspore::parallel::Device> &devices, Group *const comm_group) {
430   RankList rank_list;
431   (void)std::transform(devices.begin(), devices.end(), std::back_inserter(rank_list),
432                        [](const Device &device) { return device.rank(); });
433   if (CheckDeviceList(rank_list) != SUCCESS) {
434     MS_LOG(ERROR) << "Create communication group failed, the rank list is: " << rank_list;
435     return FAILED;
436   }
437   if (gm_.CreateGroup(group_name, devices, comm_group) != SUCCESS) {
438     return FAILED;
439   }
440   group_to_rank_[group_name] = RankListName(rank_list);
441   return SUCCESS;
442 }
443 
444 // Create the group with only the given devices' ranks.
CreateGroup(const RankList & dev_ranks,Group * const comm_group)445 Status DeviceManager::CreateGroup(const RankList &dev_ranks, Group *const comm_group) {
446   mindspore::HashSet<int64_t> rank_set(dev_ranks.begin(), dev_ranks.end());
447   if (dev_ranks.size() != rank_set.size()) {
448     MS_LOG(ERROR) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list";
449     return FAILED;
450   }
451   if (CheckDeviceList(dev_ranks) != SUCCESS) {
452     MS_LOG(ERROR) << "Create communication group failed, the rank list is: " << dev_ranks;
453     return FAILED;
454   }
455   std::string group_name = GenerateGroupNameByRanks(dev_ranks);
456   auto dev_list = CreateDeviceListByRankList(dev_ranks);
457   return CreateGroup(group_name, dev_list, comm_group);
458 }
459 
Clear()460 void DeviceManager::Clear() {
461   devices_.clear();
462   stage_devices_.clear();
463   gm_.Clear();
464 }
465 }  // namespace parallel
466 }  // namespace mindspore
467