• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/device_manager.h"
18 
19 #include <algorithm>
20 #include <string>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "frontend/parallel/step_parallel.h"
25 #include "utils/log_adapter.h"
26 
27 namespace mindspore {
28 namespace parallel {
// Process-wide device manager. Starts out null; InitDevice() in this file creates (or replaces) it,
// and CheckGlobalDeviceManager() reports when it is still unset.
29 DeviceManagerPtr g_device_manager = nullptr;
InitDevice(int64_t device_num,int64_t global_rank,const std::string & backend,const std::vector<int64_t> & stage)30 bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
31                 const std::vector<int64_t> &stage) {
32   if (device_num <= 0) {
33     MS_LOG(ERROR) << "'device_num' must be positive.";
34     return false;
35   }
36   if (global_rank < 0) {
37     MS_LOG(ERROR) << "'global_rank' must be nonnegative.";
38     return false;
39   }
40   if (device_num > MAX_DEVICE_NUM) {
41     MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << ".";
42     return false;
43   }
44   // 'device_num_converted' must be the power of 2
45   if ((LongToUlong(device_num) & LongToUlong(device_num - 1)) != 0) {
46     MS_LOG(ERROR) << "'device_num' must be the power of 2.";
47     return false;
48   }
49   if (global_rank >= device_num) {
50     MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'.";
51     return false;
52   }
53   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
54     MS_LOG(ERROR) << "Invalid backend: " << backend;
55     return false;
56   }
57   if (stage.empty()) {
58     MS_LOG(ERROR) << "The size of stage must be positive";
59     return false;
60   }
61 
62   RankList devices, stage_map;
63   for (int64_t i = 0; i < device_num; ++i) {
64     devices.push_back(i);
65   }
66 
67   int64_t summed_value = 0;
68   for (auto begin = stage.begin(); begin != stage.end(); ++begin) {
69     if (*begin <= 0) {
70       MS_LOG(ERROR) << "The value in the pipeline stages should be positive value";
71       return false;
72     }
73     summed_value += *begin;
74     stage_map.push_back(*begin);
75   }
76 
77   if (summed_value != device_num) {
78     MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num "
79                   << device_num;
80     return false;
81   }
82 
83   for (auto &ele : stage_map) {
84     MS_LOG(DEBUG) << "Obtained stage id: " << ele;
85   }
86   if (g_device_manager) {
87     auto gm = g_device_manager->group_manager();
88     g_device_manager = std::make_shared<DeviceManager>();
89     g_device_manager->set_group_manager(gm);
90   } else {
91     g_device_manager = std::make_shared<DeviceManager>();
92   }
93   if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
94     MS_LOG(INFO) << "Device initialization succeeds.";
95     return true;
96   }
97 
98   MS_LOG(ERROR) << "Device initialization fails.";
99   return false;
100 }
101 
CheckGlobalDeviceManager()102 void CheckGlobalDeviceManager() {
103   if (g_device_manager == nullptr) {
104     MS_LOG(EXCEPTION) << "Device information has not been set!";
105   }
106 }
107 
GetListMemberByIndex(size_t index,const RankList & devices)108 int64_t GetListMemberByIndex(size_t index, const RankList &devices) {
109   size_t i = 0;
110   int64_t result = 0;
111   if ((devices.empty()) || (index >= devices.size())) {
112     MS_LOG(EXCEPTION) << "Index is out of the list scope";
113   }
114   auto it = devices.begin();
115   for (; it != devices.end(); ++it) {
116     if (i == index) {
117       result = *it;
118       break;
119     }
120     ++i;
121   }
122   return result;
123 }
124 
GetListMemberByIndex(size_t index,const std::vector<std::shared_ptr<Device>> & device_list)125 std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std::shared_ptr<Device>> &device_list) {
126   size_t i = 0;
127   std::shared_ptr<Device> result;
128   if ((device_list.empty()) || (index >= device_list.size())) {
129     MS_LOG(EXCEPTION) << "Index is out of the list scope";
130   }
131   auto it = device_list.begin();
132   for (; it != device_list.end(); ++it) {
133     if (i == index) {
134       result = *it;
135       break;
136     }
137     ++i;
138   }
139   return result;
140 }
141 
142 // E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
143 // therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
Init(const RankList & devices,int64_t global_device_rank,const RankList & stage_map,const std::string & backend)144 Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
145                            const std::string &backend) {
146   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
147     MS_LOG(ERROR) << "Invalid backend: " << backend;
148     return FAILED;
149   }
150 
151   if (stage_map.empty() || devices.empty()) {
152     MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
153     return FAILED;
154   }
155 
156   for (auto &dev : devices) {
157     std::shared_ptr<Device> one = std::make_shared<Device>(dev);
158     devices_.push_back(one);
159   }
160 
161   size_t global_index = 0;
162   for (auto &stage : stage_map) {
163     int64_t num_device = stage;
164     if (num_device > MAX_DEVICE_NUM) {
165       MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
166       return FAILED;
167     }
168     if (num_device <= 0) {
169       MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
170       return FAILED;
171     }
172     RankList curr_dev_list;
173     for (int64_t i = 0; i < num_device; ++i) {
174       curr_dev_list.push_back(GetListMemberByIndex(global_index, devices));
175       global_index++;
176     }
177     stage_devices_.push_back(curr_dev_list);
178   }
179 
180   std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
181   device_ = dev;
182 
183   global_rank_ = global_device_rank;
184   stage_num_ = static_cast<const int64_t>(stage_map.size());
185   stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
186   rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
187   stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
188 
189   backend_ = backend;
190 
191   if (backend == HCCL_BACKEND) {
192     gm_.set_world_group(HCCL_WORLD_GROUP);
193   } else if (backend_ == NCCL_BACKEND) {
194     gm_.set_world_group(NCCL_WORLD_GROUP);
195   } else {
196     gm_.set_world_group(UNDEFINED_WORLD_GROUP);
197   }
198   MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
199                << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
200                << ", the rank index in stage is: " << rank_index_in_stage_;
201   return SUCCESS;
202 }
203 
GetDeviceListInThisStage() const204 RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
205 
GetDeviceListByStageId(int64_t stage_id) const206 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
207   if (LongToSize(stage_id) >= stage_devices_.size())
208     MS_LOG(ERROR) << "the 'stage_id': " << stage_id
209                   << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
210   RankList res;
211   int64_t index = 0;
212   for (auto &stage : stage_devices_) {
213     if (index == stage_id) {
214       return stage;
215     }
216     index++;
217   }
218   return res;
219 }
220 
CreateNewDeviceByRank(int64_t rank) const221 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
222 
CreateDeviceListByRankList(RankList ranks)223 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
224   std::vector<Device> dev_list;
225   for (auto &rank : ranks) {
226     Device one = CreateNewDeviceByRank(rank);
227     dev_list.push_back(one);
228   }
229   return dev_list;
230 }
231 
GetInstance()232 DeviceManager &DeviceManager::GetInstance() {
233   static DeviceManager instance = DeviceManager();
234   return instance;
235 }
236 
FindRankListNameByHashName(const std::string & hash_name)237 std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) {
238   std::string tmp = "WORLD_GROUP";
239   if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) {
240     return tmp;
241   }
242   auto iter = group_to_rank_.find(hash_name);
243   if (iter == group_to_rank_.end()) {
244     MS_LOG(WARNING) << "Can not find the rank list name by hash name: " << hash_name;
245     return tmp;
246   }
247   return iter->second;
248 }
249 
// Returns the decimal string form of std::hash<std::string> applied to `origin_name`.
// The hash type is spelled with its full std:: qualification so that this function does not
// depend on a using-declaration for `string` leaking in from some header.
std::string HashName(const std::string &origin_name) {
  return std::to_string(std::hash<std::string>{}(origin_name));
}
251 
252 // Group name is generated using the increasing ranks of the devices.
253 // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
254 // is '0-1-3-5-7'.
GenerateGroupNameByRanks(RankList ranks)255 std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
256   std::string rank_list_name;
257   std::vector<int64_t>::iterator it;
258   std::sort(ranks.begin(), ranks.end());  // sorted in increasing order
259   for (it = ranks.begin(); it != ranks.end(); ++it) {
260     if (it == ranks.begin()) {
261       rank_list_name = std::to_string(*it);
262     } else {
263       rank_list_name += "-" + std::to_string(*it);
264     }
265   }
266 
267   // hash rank-list-name and add ranks' size as prefix
268   std::string group_hash_name = HashName(rank_list_name);
269   std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;
270 
271   if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) {
272     if (group_to_rank_.find(group_name) == group_to_rank_.end()) {
273       rank_to_group_[rank_list_name] = group_name;
274       group_to_rank_[group_name] = rank_list_name;
275       MS_LOG(INFO) << "The rank list name is " << rank_list_name << "nd group name is " << group_name;
276     } else {
277       MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name
278                         << "the old rank list:" << group_to_rank_.find(group_name)->second
279                         << "the group name: " << group_name;
280     }
281   }
282   return group_name;
283 }
284 
285 // Create the group with the given devices and the given name. The GroupManager
286 // gm_ will create a new group only if there does not exit a group with the same
287 // name. Otherwise, let the pointer g point to that group.
CreateGroup(const std::string & group_name,const std::vector<mindspore::parallel::Device> & devices)288 Group DeviceManager::CreateGroup(const std::string &group_name,
289                                  const std::vector<mindspore::parallel::Device> &devices) {
290   Group g;
291   (void)gm_.CreateGroup(group_name, devices, &g);
292   return g;
293 }
294 
295 // Create the group with only the given devices' ranks.
CreateGroup(const RankList & dev_ranks)296 Group DeviceManager::CreateGroup(const RankList &dev_ranks) {
297   std::unordered_set<int64_t> rank_set(dev_ranks.begin(), dev_ranks.end());
298   if (dev_ranks.size() != rank_set.size()) {
299     MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list";
300   }
301 
302   std::string group_name = GenerateGroupNameByRanks(dev_ranks);
303   auto dev_list = CreateDeviceListByRankList(dev_ranks);
304   return CreateGroup(group_name, dev_list);
305 }
306 
Clear()307 void DeviceManager::Clear() {
308   devices_.clear();
309   stage_devices_.clear();
310   gm_.Clear();
311 }
312 }  // namespace parallel
313 }  // namespace mindspore
314