1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/group_manager.h"
18 #include <algorithm>
19 #include <vector>
20 #include <utility>
21 #if !defined(NO_DLIB) || defined(ENABLE_GPU)
22 #include "backend/session/executor_manager.h"
23 #else
24 #include "frontend/parallel/parallel_stub/executor_manager_stub.h"
25 #endif
26 #include "frontend/parallel/device_manager.h"
27 #include "utils/comm_manager.h"
28 #include "utils/ms_context.h"
29
30 namespace mindspore {
31 namespace parallel {
Group()32 Group::Group() {
33 name_.clear();
34 devices_.clear();
35 }
36
Init(const std::string & name,const std::vector<Device> & devices)37 Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
38 this->name_ = name;
39 this->devices_ = devices;
40 return Status::SUCCESS;
41 }
42
GetDevicesList() const43 std::vector<Device> Group::GetDevicesList() const { return devices_; }
44
IsInThisGroup(int64_t device_rank)45 bool Group::IsInThisGroup(int64_t device_rank) {
46 for (auto &device : devices_) {
47 if (device.rank() == device_rank) {
48 return true;
49 }
50 }
51 return false;
52 }
53
54 // Get the position of the device in the group
GetIndex(size_t * index)55 Status Group::GetIndex(size_t *index) {
56 size_t pos = 0;
57 CheckGlobalDeviceManager();
58 int64_t rank = g_device_manager->global_rank();
59 for (auto &device : devices_) {
60 if (device.rank() == rank) {
61 *index = pos;
62 return Status::SUCCESS;
63 } else {
64 pos++;
65 }
66 }
67 MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
68 return Status::FAILED;
69 }
70
GroupManager()71 GroupManager::GroupManager() { groups_.clear(); }
72
73 #if !defined(NO_DLIB) || defined(ENABLE_GPU)
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id)74 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
75 const std::vector<uint32_t> ranks, uint32_t device_id) {
76 // The group operation thread must be same with nccl init thread in the GPU device.
77 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
78 (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
79 return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
80 } else {
81 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
82 MS_EXCEPTION_IF_NULL(executor);
83 return executor->CreateCommGroup(group_name, ranks);
84 }
85 }
86
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id)87 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
88 uint32_t device_id) {
89 // The group operation thread must be same with nccl init thread in the GPU device.
90 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
91 (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
92 return CommManager::GetInstance().DestroyGroup(group_name);
93 } else {
94 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
95 MS_EXCEPTION_IF_NULL(executor);
96 return executor->DestroyCommGroup(group_name);
97 }
98 }
99
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)100 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
101 // Create group through the executor
102 auto context_ptr = MsContext::GetInstance();
103 MS_EXCEPTION_IF_NULL(context_ptr);
104 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
105 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
106 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
107 MS_EXCEPTION_IF_NULL(executor);
108 for (auto &group : group_info) {
109 bool ret = true;
110 // The group operation thread must be same with nccl init thread in the GPU device.
111 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
112 (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
113 ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
114 } else {
115 ret = executor->CreateCommGroup(group.first, group.second);
116 }
117 if (!ret) {
118 MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
119 return FAILED;
120 }
121 MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
122 }
123
124 return SUCCESS;
125 }
126 #else
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id)127 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
128 const std::vector<uint32_t> ranks, uint32_t device_id) {
129 MS_LOG(WARNING) << "Create group in stub";
130 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
131 MS_EXCEPTION_IF_NULL(executor);
132 return executor->CreateCommGroup(group_name, ranks);
133 }
134
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id)135 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
136 uint32_t device_id) {
137 MS_LOG(WARNING) << "Destroy group in stub";
138 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
139 MS_EXCEPTION_IF_NULL(executor);
140 return executor->DestroyCommGroup(group_name);
141 }
142
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)143 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
144 // Create group through the executor
145 auto context_ptr = MsContext::GetInstance();
146 MS_EXCEPTION_IF_NULL(context_ptr);
147 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
148 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
149 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
150 MS_EXCEPTION_IF_NULL(executor);
151 for (auto &group : group_info) {
152 bool ret = executor->CreateCommGroup(group.first, group.second);
153 if (!ret) {
154 MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
155 return FAILED;
156 }
157 MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
158 }
159
160 return SUCCESS;
161 }
162 #endif
CreateGroup(const std::string & group_name,const std::vector<Device> & devices,mindspore::parallel::Group * const group)163 Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
164 mindspore::parallel::Group *const group) {
165 // it is simple to use size to determine whether it is a world group
166 uint32_t world_size = 0;
167 (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
168
169 if (devices.size() == world_size) {
170 auto iter = groups_.find(world_group_);
171 if (iter == groups_.end()) {
172 (void)group->Init(world_group_, devices);
173 groups_[world_group_] = *group;
174 } else {
175 *group = iter->second;
176 }
177 MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it.";
178 return Status::SUCCESS;
179 }
180
181 auto it = groups_.find(group_name);
182 // If there already exits a group with the desired 'name',
183 // let the pointer point to the group.
184 if (it != groups_.end()) {
185 *group = it->second;
186 return Status::SUCCESS;
187 } else {
188 (void)group->Init(group_name, devices);
189 groups_[group_name] = *group;
190
191 vector<uint32_t> ranks;
192 (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
193 [](const Device dev) { return (uint32_t)dev.rank(); });
194 // Create group through the executor
195 auto context_ptr = MsContext::GetInstance();
196 MS_EXCEPTION_IF_NULL(context_ptr);
197 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
198 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
199
200 std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
201 group_info_.push_back(group_info);
202
203 bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
204 if (!ret) {
205 MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
206 return Status::FAILED;
207 }
208
209 MS_LOG(INFO) << "Create group success, group name is " << group_name;
210 return Status::SUCCESS;
211 }
212 }
213
DestroyGroup(const std::string & group_name)214 Status GroupManager::DestroyGroup(const std::string &group_name) {
215 auto context_ptr = MsContext::GetInstance();
216 MS_EXCEPTION_IF_NULL(context_ptr);
217 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
218 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
219 bool ret = DestroyGroupByExecutor(device_name, group_name, device_id);
220 if (!ret) {
221 return Status::FAILED;
222 }
223 return Status::SUCCESS;
224 }
225
DestroyGroup(mindspore::parallel::Group * const group)226 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
227 std::string name = (*group).name();
228 auto it = groups_.find(name);
229 if (it == groups_.end()) {
230 MS_LOG(ERROR) << "Could not find group name :" << name;
231 return Status::FAILED;
232 }
233 (void)groups_.erase(it);
234 return DestroyGroup(name);
235 }
236
DestroyAllGroups()237 Status GroupManager::DestroyAllGroups() {
238 for (auto &it : groups_) {
239 std::string name = it.first;
240 auto ret = DestroyGroup(name);
241 if (ret != Status::SUCCESS) {
242 return Status::FAILED;
243 }
244 }
245 groups_.clear();
246 return Status::SUCCESS;
247 }
248
GetRankID(const std::string & name,uint32_t * const rank_id)249 Status GroupManager::GetRankID(const std::string &name, uint32_t *const rank_id) {
250 auto it = groups_.find(name);
251 if (it == groups_.end()) {
252 MS_LOG(ERROR) << "Could not find group name :" << name;
253 return Status::FAILED;
254 }
255 bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
256 if (!ret) {
257 return Status::FAILED;
258 }
259 return Status::SUCCESS;
260 }
261
GetRankSize(const std::string & name,uint32_t * const rank_size)262 Status GroupManager::GetRankSize(const std::string &name, uint32_t *const rank_size) {
263 auto it = groups_.find(name);
264 if (it == groups_.end()) {
265 MS_LOG(ERROR) << "Could not find group name :" << name;
266 return Status::FAILED;
267 }
268 bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
269 if (!ret) {
270 return Status::FAILED;
271 }
272 return Status::SUCCESS;
273 }
274
FindGroup(const std::string & name,mindspore::parallel::Group ** group)275 Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
276 auto it = groups_.find(name);
277 if (it == groups_.end()) {
278 return Status::FAILED;
279 }
280 *group = &it->second;
281 return Status::SUCCESS;
282 }
283
Clear()284 void GroupManager::Clear() { (void)DestroyAllGroups(); }
285 } // namespace parallel
286 } // namespace mindspore
287