• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/group_manager.h"
18 #include <algorithm>
19 #include <vector>
20 #include <utility>
21 #if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
22 #include "backend/common/session/executor_manager.h"
23 #else
24 #include "frontend/parallel/parallel_stub/executor_manager_stub.h"
25 #endif
26 #include "frontend/parallel/device_manager.h"
27 #include "include/common/utils/comm_manager.h"
28 #include "include/common/utils/parallel_context.h"
29 #include "utils/ms_context.h"
30 
31 namespace mindspore {
32 namespace parallel {
Group()33 Group::Group() {
34   name_.clear();
35   devices_.clear();
36 }
37 
Init(const std::string & name,const std::vector<Device> & devices)38 Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
39   this->name_ = name;
40   this->devices_ = devices;
41   return Status::SUCCESS;
42 }
43 
GetDevicesList() const44 std::vector<Device> Group::GetDevicesList() const { return devices_; }
45 
IsInThisGroup(int64_t device_rank)46 bool Group::IsInThisGroup(int64_t device_rank) {
47   for (auto &device : devices_) {
48     if (device.rank() == device_rank) {
49       return true;
50     }
51   }
52   return false;
53 }
54 
55 // Get the position of the device in the group
GetIndex(size_t * index)56 Status Group::GetIndex(size_t *index) {
57   size_t pos = 0;
58   auto rank = ParallelContext::GetInstance()->global_rank();
59   if (!ParallelContext::GetInstance()->do_transform()) {
60     CheckGlobalDeviceManager();
61     MS_EXCEPTION_IF_NULL(g_device_manager);
62     rank = g_device_manager->global_rank();
63   }
64   for (auto &device : devices_) {
65     if (device.rank() == rank) {
66       *index = pos;
67       return Status::SUCCESS;
68     } else {
69       pos++;
70     }
71   }
72   MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
73   return Status::FAILED;
74 }
75 
GetIndexByRank(int64_t rank,size_t * index)76 Status Group::GetIndexByRank(int64_t rank, size_t *index) {
77   size_t pos = 0;
78   for (auto &device : devices_) {
79     if (device.rank() == rank) {
80       *index = pos;
81       return Status::SUCCESS;
82     } else {
83       pos++;
84     }
85   }
86   MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
87   return Status::FAILED;
88 }
89 
GroupManager()90 GroupManager::GroupManager() { groups_.clear(); }
91 
92 #if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id) const93 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
94                                          const std::vector<uint32_t> ranks, uint32_t device_id) const {
95   // The group operation thread must be same with nccl init thread in the GPU device.
96   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
97       (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
98     return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
99   } else {
100     auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
101     MS_EXCEPTION_IF_NULL(executor);
102     return executor->CreateCommGroup(group_name, ranks);
103   }
104 }
105 
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id) const106 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
107                                           uint32_t device_id) const {
108   // The group operation thread must be same with nccl init thread in the GPU device.
109   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
110       (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
111     return CommManager::GetInstance().DestroyGroup(group_name);
112   } else {
113     auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
114     MS_EXCEPTION_IF_NULL(executor);
115     return executor->DestroyCommGroup(group_name);
116   }
117 }
118 
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)119 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
120   // Create group through the executor
121   auto context_ptr = MsContext::GetInstance();
122   MS_EXCEPTION_IF_NULL(context_ptr);
123   auto print_func = [](const auto &group, bool judge) {
124     if (judge) {
125       MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
126     } else {
127       MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
128     }
129   };
130   bool ret = true;
131   // The group operation thread must be same with nccl init thread in the GPU device.
132   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
133       (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
134     for (auto &group : group_info) {
135       ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
136       print_func(group, ret);
137       if (!ret) {
138         return FAILED;
139       }
140     }
141   } else {
142     std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
143     uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
144     auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
145     MS_EXCEPTION_IF_NULL(executor);
146     for (auto &group : group_info) {
147       ret = executor->CreateCommGroup(group.first, group.second);
148       print_func(group, ret);
149       if (!ret) {
150         return FAILED;
151       }
152     }
153   }
154   return SUCCESS;
155 }
156 #else
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id) const157 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
158                                          const std::vector<uint32_t> ranks, uint32_t device_id) const {
159   MS_LOG(WARNING) << "Create group in stub";
160   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
161   MS_EXCEPTION_IF_NULL(executor);
162   return executor->CreateCommGroup(group_name, ranks);
163 }
164 
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id) const165 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
166                                           uint32_t device_id) const {
167   MS_LOG(WARNING) << "Destroy group in stub";
168   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
169   MS_EXCEPTION_IF_NULL(executor);
170   return executor->DestroyCommGroup(group_name);
171 }
172 
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)173 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
174   // Create group through the executor
175   auto context_ptr = MsContext::GetInstance();
176   MS_EXCEPTION_IF_NULL(context_ptr);
177   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
178   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
179   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
180   MS_EXCEPTION_IF_NULL(executor);
181   for (auto &group : group_info) {
182     bool ret = executor->CreateCommGroup(group.first, group.second);
183     if (!ret) {
184       MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
185       return FAILED;
186     }
187     MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
188   }
189 
190   return SUCCESS;
191 }
192 #endif
193 
CreateGroup(const std::string & group_name,const std::vector<Device> & devices,mindspore::parallel::Group * const group)194 Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
195                                  mindspore::parallel::Group *const group) {
196   // it is simple to use size to determine whether it is a world group
197   if (devices.size() == g_device_manager->DeviceNum()) {
198     auto iter = groups_.find(world_group_);
199     if (iter == groups_.end()) {
200       (void)group->Init(world_group_, devices);
201       groups_[world_group_] = *group;
202     } else {
203       *group = iter->second;
204     }
205     MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it.";
206     return Status::SUCCESS;
207   }
208 
209   auto it = groups_.find(group_name);
210   // If there already exits a group with the desired 'name',
211   // let the pointer point to the group.
212   if (it != groups_.end()) {
213     *group = it->second;
214     return Status::SUCCESS;
215   } else {
216     (void)group->Init(group_name, devices);
217     groups_[group_name] = *group;
218 
219     std::vector<uint32_t> ranks;
220     (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
221                          [](const Device dev) { return static_cast<uint32_t>(dev.rank()); });
222     // Create group through the executor
223     auto context_ptr = MsContext::GetInstance();
224     MS_EXCEPTION_IF_NULL(context_ptr);
225     std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
226     uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
227 
228     std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
229     group_info_.push_back(group_info);
230 
231     bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
232     if (!ret) {
233       MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
234       return Status::FAILED;
235     }
236 
237     MS_LOG(INFO) << "Create group success, group name is " << group_name;
238     return Status::SUCCESS;
239   }
240 }
241 
DestroyGroup(const std::string & group_name) const242 Status GroupManager::DestroyGroup(const std::string &group_name) const {
243   auto context_ptr = MsContext::GetInstance();
244   MS_EXCEPTION_IF_NULL(context_ptr);
245   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
246   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
247   bool ret = DestroyGroupByExecutor(device_name, group_name, device_id);
248   if (!ret) {
249     return Status::FAILED;
250   }
251   return Status::SUCCESS;
252 }
253 
DestroyGroup(mindspore::parallel::Group * const group)254 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
255   const std::string name = (*group).name();
256   auto it = groups_.find(name);
257   if (it == groups_.cend()) {
258     MS_LOG(ERROR) << "Could not find group name :" << name;
259     return Status::FAILED;
260   }
261   (void)groups_.erase(it);
262   return DestroyGroup(name);
263 }
264 
DestroyAllGroups()265 Status GroupManager::DestroyAllGroups() {
266   for (auto it = groups_.cbegin(); it != groups_.cend(); ++it) {
267     std::string name = it->first;
268     auto ret = DestroyGroup(name);
269     if (ret != Status::SUCCESS) {
270       return Status::FAILED;
271     }
272   }
273   groups_.clear();
274   return Status::SUCCESS;
275 }
276 
GetRankID(const std::string & name,uint32_t * const rank_id)277 Status GroupManager::GetRankID(const std::string &name, uint32_t *const rank_id) {
278   auto it = groups_.find(name);
279   if (it == groups_.cend()) {
280     MS_LOG(ERROR) << "Could not find group name :" << name;
281     return Status::FAILED;
282   }
283   bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
284   if (!ret) {
285     return Status::FAILED;
286   }
287   return Status::SUCCESS;
288 }
289 
GetRankSize(const std::string & name,uint32_t * const rank_size)290 Status GroupManager::GetRankSize(const std::string &name, uint32_t *const rank_size) {
291   auto it = groups_.find(name);
292   if (it == groups_.cend()) {
293     MS_LOG(ERROR) << "Could not find group name :" << name;
294     return Status::FAILED;
295   }
296   bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
297   if (!ret) {
298     return Status::FAILED;
299   }
300   return Status::SUCCESS;
301 }
302 
FindGroup(const std::string & name,mindspore::parallel::Group ** group)303 Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
304   auto it = groups_.find(name);
305   if (it == groups_.cend()) {
306     return Status::FAILED;
307   }
308   *group = &it->second;
309   return Status::SUCCESS;
310 }
311 
Clear()312 void GroupManager::Clear() { (void)DestroyAllGroups(); }
313 }  // namespace parallel
314 }  // namespace mindspore
315