1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/group_manager.h"
18 #include <algorithm>
19 #include <vector>
20 #include <utility>
21 #if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
22 #include "backend/common/session/executor_manager.h"
23 #else
24 #include "frontend/parallel/parallel_stub/executor_manager_stub.h"
25 #endif
26 #include "frontend/parallel/device_manager.h"
27 #include "include/common/utils/comm_manager.h"
28 #include "include/common/utils/parallel_context.h"
29 #include "utils/ms_context.h"
30
31 namespace mindspore {
32 namespace parallel {
Group()33 Group::Group() {
34 name_.clear();
35 devices_.clear();
36 }
37
Init(const std::string & name,const std::vector<Device> & devices)38 Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
39 this->name_ = name;
40 this->devices_ = devices;
41 return Status::SUCCESS;
42 }
43
GetDevicesList() const44 std::vector<Device> Group::GetDevicesList() const { return devices_; }
45
IsInThisGroup(int64_t device_rank)46 bool Group::IsInThisGroup(int64_t device_rank) {
47 for (auto &device : devices_) {
48 if (device.rank() == device_rank) {
49 return true;
50 }
51 }
52 return false;
53 }
54
55 // Get the position of the device in the group
GetIndex(size_t * index)56 Status Group::GetIndex(size_t *index) {
57 size_t pos = 0;
58 auto rank = ParallelContext::GetInstance()->global_rank();
59 if (!ParallelContext::GetInstance()->do_transform()) {
60 CheckGlobalDeviceManager();
61 MS_EXCEPTION_IF_NULL(g_device_manager);
62 rank = g_device_manager->global_rank();
63 }
64 for (auto &device : devices_) {
65 if (device.rank() == rank) {
66 *index = pos;
67 return Status::SUCCESS;
68 } else {
69 pos++;
70 }
71 }
72 MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
73 return Status::FAILED;
74 }
75
GetIndexByRank(int64_t rank,size_t * index)76 Status Group::GetIndexByRank(int64_t rank, size_t *index) {
77 size_t pos = 0;
78 for (auto &device : devices_) {
79 if (device.rank() == rank) {
80 *index = pos;
81 return Status::SUCCESS;
82 } else {
83 pos++;
84 }
85 }
86 MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!";
87 return Status::FAILED;
88 }
89
GroupManager()90 GroupManager::GroupManager() { groups_.clear(); }
91
92 #if (!defined(_WIN32) && !defined(__APPLE__) && !(defined(ENABLE_TESTCASES) || defined(ENABLE_TEST)))
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id) const93 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
94 const std::vector<uint32_t> ranks, uint32_t device_id) const {
95 // The group operation thread must be same with nccl init thread in the GPU device.
96 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
97 (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
98 return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
99 } else {
100 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
101 MS_EXCEPTION_IF_NULL(executor);
102 return executor->CreateCommGroup(group_name, ranks);
103 }
104 }
105
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id) const106 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
107 uint32_t device_id) const {
108 // The group operation thread must be same with nccl init thread in the GPU device.
109 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
110 (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
111 return CommManager::GetInstance().DestroyGroup(group_name);
112 } else {
113 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
114 MS_EXCEPTION_IF_NULL(executor);
115 return executor->DestroyCommGroup(group_name);
116 }
117 }
118
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)119 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
120 // Create group through the executor
121 auto context_ptr = MsContext::GetInstance();
122 MS_EXCEPTION_IF_NULL(context_ptr);
123 auto print_func = [](const auto &group, bool judge) {
124 if (judge) {
125 MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
126 } else {
127 MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
128 }
129 };
130 bool ret = true;
131 // The group operation thread must be same with nccl init thread in the GPU device.
132 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) ||
133 (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice)) {
134 for (auto &group : group_info) {
135 ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
136 print_func(group, ret);
137 if (!ret) {
138 return FAILED;
139 }
140 }
141 } else {
142 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
143 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
144 auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
145 MS_EXCEPTION_IF_NULL(executor);
146 for (auto &group : group_info) {
147 ret = executor->CreateCommGroup(group.first, group.second);
148 print_func(group, ret);
149 if (!ret) {
150 return FAILED;
151 }
152 }
153 }
154 return SUCCESS;
155 }
156 #else
CreateGroupByExecutor(const std::string & device_name,const std::string & group_name,const std::vector<uint32_t> ranks,uint32_t device_id) const157 bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
158 const std::vector<uint32_t> ranks, uint32_t device_id) const {
159 MS_LOG(WARNING) << "Create group in stub";
160 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
161 MS_EXCEPTION_IF_NULL(executor);
162 return executor->CreateCommGroup(group_name, ranks);
163 }
164
DestroyGroupByExecutor(const std::string & device_name,const std::string & group_name,uint32_t device_id) const165 bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
166 uint32_t device_id) const {
167 MS_LOG(WARNING) << "Destroy group in stub";
168 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
169 MS_EXCEPTION_IF_NULL(executor);
170 return executor->DestroyCommGroup(group_name);
171 }
172
CreateGroups(const std::vector<std::pair<std::string,std::vector<uint32_t>>> & group_info)173 Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
174 // Create group through the executor
175 auto context_ptr = MsContext::GetInstance();
176 MS_EXCEPTION_IF_NULL(context_ptr);
177 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
178 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
179 auto executor = parallel::ExecutorManager::Instance().GetExecutor(device_name, device_id);
180 MS_EXCEPTION_IF_NULL(executor);
181 for (auto &group : group_info) {
182 bool ret = executor->CreateCommGroup(group.first, group.second);
183 if (!ret) {
184 MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
185 return FAILED;
186 }
187 MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
188 }
189
190 return SUCCESS;
191 }
192 #endif
193
CreateGroup(const std::string & group_name,const std::vector<Device> & devices,mindspore::parallel::Group * const group)194 Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
195 mindspore::parallel::Group *const group) {
196 // it is simple to use size to determine whether it is a world group
197 if (devices.size() == g_device_manager->DeviceNum()) {
198 auto iter = groups_.find(world_group_);
199 if (iter == groups_.end()) {
200 (void)group->Init(world_group_, devices);
201 groups_[world_group_] = *group;
202 } else {
203 *group = iter->second;
204 }
205 MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it.";
206 return Status::SUCCESS;
207 }
208
209 auto it = groups_.find(group_name);
210 // If there already exits a group with the desired 'name',
211 // let the pointer point to the group.
212 if (it != groups_.end()) {
213 *group = it->second;
214 return Status::SUCCESS;
215 } else {
216 (void)group->Init(group_name, devices);
217 groups_[group_name] = *group;
218
219 std::vector<uint32_t> ranks;
220 (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
221 [](const Device dev) { return static_cast<uint32_t>(dev.rank()); });
222 // Create group through the executor
223 auto context_ptr = MsContext::GetInstance();
224 MS_EXCEPTION_IF_NULL(context_ptr);
225 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
226 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
227
228 std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
229 group_info_.push_back(group_info);
230
231 bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
232 if (!ret) {
233 MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
234 return Status::FAILED;
235 }
236
237 MS_LOG(INFO) << "Create group success, group name is " << group_name;
238 return Status::SUCCESS;
239 }
240 }
241
DestroyGroup(const std::string & group_name) const242 Status GroupManager::DestroyGroup(const std::string &group_name) const {
243 auto context_ptr = MsContext::GetInstance();
244 MS_EXCEPTION_IF_NULL(context_ptr);
245 std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
246 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
247 bool ret = DestroyGroupByExecutor(device_name, group_name, device_id);
248 if (!ret) {
249 return Status::FAILED;
250 }
251 return Status::SUCCESS;
252 }
253
DestroyGroup(mindspore::parallel::Group * const group)254 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
255 const std::string name = (*group).name();
256 auto it = groups_.find(name);
257 if (it == groups_.cend()) {
258 MS_LOG(ERROR) << "Could not find group name :" << name;
259 return Status::FAILED;
260 }
261 (void)groups_.erase(it);
262 return DestroyGroup(name);
263 }
264
DestroyAllGroups()265 Status GroupManager::DestroyAllGroups() {
266 for (auto it = groups_.cbegin(); it != groups_.cend(); ++it) {
267 std::string name = it->first;
268 auto ret = DestroyGroup(name);
269 if (ret != Status::SUCCESS) {
270 return Status::FAILED;
271 }
272 }
273 groups_.clear();
274 return Status::SUCCESS;
275 }
276
GetRankID(const std::string & name,uint32_t * const rank_id)277 Status GroupManager::GetRankID(const std::string &name, uint32_t *const rank_id) {
278 auto it = groups_.find(name);
279 if (it == groups_.cend()) {
280 MS_LOG(ERROR) << "Could not find group name :" << name;
281 return Status::FAILED;
282 }
283 bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
284 if (!ret) {
285 return Status::FAILED;
286 }
287 return Status::SUCCESS;
288 }
289
GetRankSize(const std::string & name,uint32_t * const rank_size)290 Status GroupManager::GetRankSize(const std::string &name, uint32_t *const rank_size) {
291 auto it = groups_.find(name);
292 if (it == groups_.cend()) {
293 MS_LOG(ERROR) << "Could not find group name :" << name;
294 return Status::FAILED;
295 }
296 bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
297 if (!ret) {
298 return Status::FAILED;
299 }
300 return Status::SUCCESS;
301 }
302
FindGroup(const std::string & name,mindspore::parallel::Group ** group)303 Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
304 auto it = groups_.find(name);
305 if (it == groups_.cend()) {
306 return Status::FAILED;
307 }
308 *group = &it->second;
309 return Status::SUCCESS;
310 }
311
Clear()312 void GroupManager::Clear() { (void)DestroyAllGroups(); }
313 } // namespace parallel
314 } // namespace mindspore
315