• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/group_manager.h"
18 #include <algorithm>
19 #include <vector>
20 #include <utility>
21 #if !defined(NO_DLIB) || defined(ENABLE_GPU)
22 #include "backend/session/executor_manager.h"
23 #else
24 #include "frontend/parallel/parallel_stub/executor_manager_stub.h"
25 #endif
26 #include "frontend/parallel/device_manager.h"
27 #include "utils/comm_manager.h"
28 #include "utils/ms_context.h"
29 
30 namespace mindspore {
31 namespace parallel {
Group()32 Group::Group() {
33   name_.clear();
34   devices_.clear();
35 }
36 
Init(const std::string & name,const std::vector<Device> & devices)37 Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
38   this->name_ = name;
39   this->devices_ = devices;
40   return Status::SUCCESS;
41 }
42 
GetDevicesList() const43 std::vector<Device> Group::GetDevicesList() const { return devices_; }
44 
IsInThisGroup(int64_t device_rank)45 bool Group::IsInThisGroup(int64_t device_rank) {
46   for (auto &device : devices_) {
47     if (device.rank() == device_rank) {
48       return true;
49     }
50   }
51   return false;
52 }
53 
54 // Get the position of the device in the group
GetIndex(size_t * index)55 Status Group::GetIndex(size_t *index) {
56   size_t pos = 0;
57   CheckGlobalDeviceManager();
58   int64_t rank = g_device_manager->global_rank();
59   for (auto &device : devices_) {
60     if (device.rank() == rank) {
61       *index = pos;
62       return Status::SUCCESS;
63     } else {
64       pos++;
65     }
66   }
67   MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
68   return Status::FAILED;
69 }
70 
GroupManager()71 GroupManager::GroupManager() { groups_.clear(); }
72 
73 #if !defined(NO_DLIB) || defined(ENABLE_GPU)
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id)74 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
75                                          const std::vector<uint32_t> ranks, uint32_t device_id) {
76   // The group operation thread must be same with nccl init thread in the GPU device.
77   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
78       (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
79     return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
80   } else {
81     auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
82     MS_EXCEPTION_IF_NULL(executor);
83     return executor->CreateCommGroup(group_name, ranks);
84   }
85 }
86 
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id)87 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
88                                           uint32_t device_id) {
89   // The group operation thread must be same with nccl init thread in the GPU device.
90   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
91       (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
92     return CommManager::GetInstance().DestroyGroup(group_name);
93   } else {
94     auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
95     MS_EXCEPTION_IF_NULL(executor);
96     return executor->DestroyCommGroup(group_name);
97   }
98 }
99 
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)100 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
101   // Create group through the executor
102   auto context_ptr = MsContext::GetInstance();
103   MS_EXCEPTION_IF_NULL(context_ptr);
104   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
105   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
106   auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
107   MS_EXCEPTION_IF_NULL(executor);
108   for (auto &group : group_info) {
109     bool ret = true;
110     // The group operation thread must be same with nccl init thread in the GPU device.
111     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
112         (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
113       ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
114     } else {
115       ret = executor->CreateCommGroup(group.first, group.second);
116     }
117     if (!ret) {
118       MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
119       return FAILED;
120     }
121     MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
122   }
123 
124   return SUCCESS;
125 }
126 #else
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id)127 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
128                                          const std::vector<uint32_t> ranks, uint32_t device_id) {
129   MS_LOG(WARNING) << "Create group in stub";
130   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
131   MS_EXCEPTION_IF_NULL(executor);
132   return executor->CreateCommGroup(group_name, ranks);
133 }
134 
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id)135 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
136                                           uint32_t device_id) {
137   MS_LOG(WARNING) << "Destroy group in stub";
138   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
139   MS_EXCEPTION_IF_NULL(executor);
140   return executor->DestroyCommGroup(group_name);
141 }
142 
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)143 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
144   // Create group through the executor
145   auto context_ptr = MsContext::GetInstance();
146   MS_EXCEPTION_IF_NULL(context_ptr);
147   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
148   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
149   auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
150   MS_EXCEPTION_IF_NULL(executor);
151   for (auto &group : group_info) {
152     bool ret = executor->CreateCommGroup(group.first, group.second);
153     if (!ret) {
154       MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
155       return FAILED;
156     }
157     MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
158   }
159 
160   return SUCCESS;
161 }
162 #endif
CreateGroup(const std::string & group_name,const std::vector<Device> & devices,mindspore::parallel::Group * const group)163 Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
164                                  mindspore::parallel::Group *const group) {
165   // it is simple to use size to determine whether it is a world group
166   uint32_t world_size = 0;
167   (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
168 
169   if (devices.size() == world_size) {
170     auto iter = groups_.find(world_group_);
171     if (iter == groups_.end()) {
172       (void)group->Init(world_group_, devices);
173       groups_[world_group_] = *group;
174     } else {
175       *group = iter->second;
176     }
177     MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it.";
178     return Status::SUCCESS;
179   }
180 
181   auto it = groups_.find(group_name);
182   // If there already exits a group with the desired 'name',
183   // let the pointer point to the group.
184   if (it != groups_.end()) {
185     *group = it->second;
186     return Status::SUCCESS;
187   } else {
188     (void)group->Init(group_name, devices);
189     groups_[group_name] = *group;
190 
191     vector<uint32_t> ranks;
192     (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
193                          [](const Device dev) { return (uint32_t)dev.rank(); });
194     // Create group through the executor
195     auto context_ptr = MsContext::GetInstance();
196     MS_EXCEPTION_IF_NULL(context_ptr);
197     std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
198     uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
199 
200     std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
201     group_info_.push_back(group_info);
202 
203     bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
204     if (!ret) {
205       MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
206       return Status::FAILED;
207     }
208 
209     MS_LOG(INFO) << "Create group success, group name is " << group_name;
210     return Status::SUCCESS;
211   }
212 }
213 
DestroyGroup(const std::string & group_name)214 Status GroupManager::DestroyGroup(const std::string &group_name) {
215   auto context_ptr = MsContext::GetInstance();
216   MS_EXCEPTION_IF_NULL(context_ptr);
217   std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
218   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
219   bool ret = DestroyGroupByExecutor(device_name, group_name, device_id);
220   if (!ret) {
221     return Status::FAILED;
222   }
223   return Status::SUCCESS;
224 }
225 
DestroyGroup(mindspore::parallel::Group * const group)226 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
227   std::string name = (*group).name();
228   auto it = groups_.find(name);
229   if (it == groups_.end()) {
230     MS_LOG(ERROR) << "Could not find group name :" << name;
231     return Status::FAILED;
232   }
233   (void)groups_.erase(it);
234   return DestroyGroup(name);
235 }
236 
DestroyAllGroups()237 Status GroupManager::DestroyAllGroups() {
238   for (auto &it : groups_) {
239     std::string name = it.first;
240     auto ret = DestroyGroup(name);
241     if (ret != Status::SUCCESS) {
242       return Status::FAILED;
243     }
244   }
245   groups_.clear();
246   return Status::SUCCESS;
247 }
248 
GetRankID(const std::string & name,uint32_t * const rank_id)249 Status GroupManager::GetRankID(const std::string &name, uint32_t *const rank_id) {
250   auto it = groups_.find(name);
251   if (it == groups_.end()) {
252     MS_LOG(ERROR) << "Could not find group name :" << name;
253     return Status::FAILED;
254   }
255   bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
256   if (!ret) {
257     return Status::FAILED;
258   }
259   return Status::SUCCESS;
260 }
261 
GetRankSize(const std::string & name,uint32_t * const rank_size)262 Status GroupManager::GetRankSize(const std::string &name, uint32_t *const rank_size) {
263   auto it = groups_.find(name);
264   if (it == groups_.end()) {
265     MS_LOG(ERROR) << "Could not find group name :" << name;
266     return Status::FAILED;
267   }
268   bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
269   if (!ret) {
270     return Status::FAILED;
271   }
272   return Status::SUCCESS;
273 }
274 
FindGroup(const std::string & name,mindspore::parallel::Group ** group)275 Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
276   auto it = groups_.find(name);
277   if (it == groups_.end()) {
278     return Status::FAILED;
279   }
280   *group = &it->second;
281   return Status::SUCCESS;
282 }
283 
Clear()284 void GroupManager::Clear() { (void)DestroyAllGroups(); }
285 }  // namespace parallel
286 }  // namespace mindspore
287