• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/comm_manager.h"
18 #include "utils/convert_utils.h"
19 #include "utils/ms_context.h"
20 #include "frontend/parallel/context.h"
21 #include "frontend/parallel/group_manager.h"
22 
23 #ifndef NO_DLIB
24 #include "runtime/hccl_adapter/hccl_adapter.h"
25 #include "hccl/hcom.h"
26 #include "runtime/device/ascend/distribute/ascend_collective.h"
27 #endif
28 
29 #if defined(ENABLE_GPU)
30 #include "runtime/device/gpu/distribution/collective_init.h"
31 using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
32 using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
33 using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
34 using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
35 using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
36 #endif
37 
38 namespace mindspore {
39 #ifndef NO_DLIB
GetInstance()40 CommManager &CommManager::GetInstance() noexcept {
41   static CommManager instance("hccl");
42   return instance;
43 }
44 
// Evaluates `op` exactly once; when the result is non-zero, logs `op_name` and `group` at ERROR level and makes
// the enclosing function return false. Must be expanded inside a function whose return type accepts `false`.
// `op_name` and `group` are parenthesized so compound expressions (e.g. a ternary) stream with the intended
// precedence instead of being split by the surrounding `<<` chain.
#define HCCL_RUN_CHECK(op_name, group, op)                          \
  do {                                                              \
    auto hccl_result = (op);                                        \
    if (hccl_result != 0) {                                         \
      MS_LOG(ERROR) << (op_name) << " failed: #" << (group) << "#"; \
      return false;                                                 \
    }                                                               \
  } while (0)
53 
// Logs an error and makes the enclosing function return false when `group` is an empty string.
// Must be expanded inside a function whose return type accepts `false`.
// The argument is parenthesized so that a compound expression (e.g. `prefix + suffix`) is tested as a whole.
#define HCCL_GROUP_CHECK_EMPTY(group)                              \
  do {                                                             \
    if ((group).empty()) {                                         \
      MS_LOG(ERROR) << "The length of group name should not be 0"; \
      return false;                                                \
    }                                                              \
  } while (0)
61 
// Logs an error and makes the enclosing function return false when `group` equals "hccl_world_group".
// Must be expanded inside a function whose return type accepts `false`.
// The argument is parenthesized so that a compound expression is compared as a whole.
#define HCCL_GROUP_CHECK_IS_WORLD(group)                                \
  do {                                                                  \
    if ((group) == "hccl_world_group") {                                \
      MS_LOG(ERROR) << "The group name should not be hccl_world_group"; \
      return false;                                                     \
    }                                                                   \
  } while (0)
69 
CreateGroupSync(const string & group,const vector<unsigned int> & rank_id_list) const70 bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
71   auto rank_size = rank_id_list.size();
72   HCCL_GROUP_CHECK_EMPTY(group);
73   HCCL_GROUP_CHECK_IS_WORLD(group);
74   auto context_ptr = MsContext::GetInstance();
75   MS_EXCEPTION_IF_NULL(context_ptr);
76   bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
77   auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
78   if (!is_task_sink && mode == kGraphMode) {
79     HcclCollectiveGroup::instance().CreateCommGroup(group, rank_id_list);
80   } else {
81     HCCL_RUN_CHECK(string("create communicate group"), group,
82                    hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
83                                                                     vector<unsigned int>(rank_id_list).data()));
84   }
85   return true;
86 }
87 
GetRankID(const string & group,unsigned int * rank_id) const88 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
89   auto context = MsContext::GetInstance();
90   MS_EXCEPTION_IF_NULL(context);
91   if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
92     HCCL_GROUP_CHECK_EMPTY(group);
93     if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
94       *rank_id = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankId(group));
95     } else {
96       HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
97     }
98   } else {
99     HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(rank_id));
100   }
101   return true;
102 }
103 
GetRankSize(const string & group,unsigned int * rank_size) const104 bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
105   auto context = MsContext::GetInstance();
106   MS_EXCEPTION_IF_NULL(context);
107   if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
108     HCCL_GROUP_CHECK_EMPTY(group);
109     if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
110       *rank_size = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankSize(group));
111     } else {
112       HCCL_RUN_CHECK(string("get rank size"), group,
113                      hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
114     }
115   } else {
116     HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(rank_size));
117   }
118   return true;
119 }
120 
DestroyGroup(const string & group) const121 bool CommManager::DestroyGroup(const string &group) const {
122   HCCL_GROUP_CHECK_EMPTY(group);
123   HCCL_GROUP_CHECK_IS_WORLD(group);
124   HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group));
125   return true;
126 }
127 #elif defined(ENABLE_GPU)
128 CommManager &CommManager::GetInstance() noexcept {
129   static CommManager instance("nccl");
130   return instance;
131 }
132 
133 bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
134   const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
135   if (!collective_handle_) {
136     MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
137   }
138   MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
139   auto create_comm_group_funcptr =
140     reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
141   MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
142   bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
143   if (!ret) {
144     MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed.";
145     return ret;
146   }
147   return ret;
148 }
149 
150 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
151   const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
152   if (!collective_handle_) {
153     MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
154   }
155   auto get_rank_id_funcptr =
156     reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
157   MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
158   int rank = (*get_rank_id_funcptr)(group);
159   *rank_id = static_cast<unsigned int>(rank);
160   MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
161   return true;
162 }
163 
164 bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
165   const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
166   if (!collective_handle_) {
167     MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
168   }
169   auto get_group_size_funcptr =
170     reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
171   MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
172   int size = (*get_group_size_funcptr)(group);
173   *rank_size = static_cast<unsigned int>(size);
174   MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
175   return true;
176 }
177 
178 bool CommManager::DestroyGroup(const string &group) const {
179   const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
180   if (!collective_handle_) {
181     MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
182   }
183   auto destroy_group_funcptr =
184     reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
185   MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
186 
187   bool ret = (*destroy_group_funcptr)(group);
188   if (!ret) {
189     MS_LOG(ERROR) << "Destroying group " << group << " failed.";
190     return ret;
191   }
192   return ret;
193 }
194 #else
195 CommManager &CommManager::GetInstance() noexcept {
196   static CommManager instance("hccl");
197   return instance;
198 }
199 
200 bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }
201 
202 bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }
203 
204 bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
205   *rank_size = NO_COMM_DLIB_RANK_SIZE;
206   return true;
207 }
208 
209 bool CommManager::DestroyGroup(const string &group) const { return true; }
210 #endif
211 
GetRank()212 uint32_t GetRank() {
213   uint32_t rank_id = 0;
214   auto ms_context = MsContext::GetInstance();
215   MS_EXCEPTION_IF_NULL(ms_context);
216   std::string world_group;
217   std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
218   if (backend == kAscendDevice) {
219     world_group = parallel::HCCL_WORLD_GROUP;
220   } else if (backend == kGPUDevice) {
221     world_group = parallel::NCCL_WORLD_GROUP;
222   } else {
223     // Other backends like CPU not support parallel, return rank_id with default 0.
224     return rank_id;
225   }
226   auto parallel_context = parallel::ParallelContext::GetInstance();
227   MS_EXCEPTION_IF_NULL(parallel_context);
228   if (parallel_context->parallel_mode() != parallel::STAND_ALONE) {
229 #ifndef NO_DLIB
230     // Check HCCL inited.
231     if (!hccl::HcclAdapter::GetInstance().Inited()) {
232       MS_LOG(DEBUG) << "HCCL not inited, return rank_id = 0";
233       return rank_id;
234     }
235 #elif defined(ENABLE_GPU)
236     // Check NCCL inited.
237     if (!CollectiveInitializer::instance().collective_inited()) {
238       MS_LOG(DEBUG) << "NCLL not inited, return rank_id = 0";
239       return rank_id;
240     }
241 #endif
242     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
243       MS_LOG(EXCEPTION) << "Get rank id failed.";
244     }
245   }
246   return rank_id;
247 }
248 
IsStandAlone()249 bool IsStandAlone() {
250   auto parallel_context = parallel::ParallelContext::GetInstance();
251   MS_EXCEPTION_IF_NULL(parallel_context);
252   return parallel_context->parallel_mode() == parallel::STAND_ALONE;
253 }
254 }  // namespace mindspore
255