1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h"

#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>

#include "utils/ms_context.h"
#include "include/backend/distributed/constants.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "runtime/collective/collective_communication_lib.h"
#include "plugin/device/cpu/hal/hardware/allreduce_impl.h"
23
24 namespace mindspore {
25 namespace device {
26 namespace cpu {
27 using distributed::cluster::topology::kDefaultRetryInterLower;
28 using distributed::cluster::topology::kDefaultRetryInterUpper;
29 using distributed::cluster::topology::kEnvNodeTimeOut;
30 using distributed::cluster::topology::kEnvRetryIntervalLower;
31 using distributed::cluster::topology::kEnvRetryIntervalUpper;
32 using distributed::recovery::RecoveryContext;
33
namespace {
// Keys used to synchronize collective-communication metadata (e.g. the unique id) through the meta server.
// The group-info key is built as "<role>_" + kGroupInfoPrefix + "<group name>".
constexpr char kGroupInfoPrefix[] = "group_info_";
constexpr char kGroupName[] = "group_name";
constexpr char kUniqueId[] = "unique_id";
}  // namespace
MsCollectiveCommLib()38 MsCollectiveCommLib::MsCollectiveCommLib() {
39 // Generate the global group name with node role.
40 global_group_name_ = kMCCLGlobalGroupName;
41 MS_LOG(INFO) << "Global group name of MindSpore collective communication library is " << global_group_name_;
42 }
43
Initialize(uint32_t global_rank,uint32_t global_rank_size,uint32_t local_rank_id)44 bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) {
45 if (initialized_) {
46 MS_LOG(WARNING) << "MsCollectiveCommLib has already been initialized.";
47 return true;
48 }
49
50 // Only use AllReduceLauncher when this is CPU backend.
51 if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
52 launcher_ = std::make_unique<AllReduceLauncher>();
53 CHECK_IF_NULL(launcher_);
54 if (!launcher_->Initialize()) {
55 MS_LOG(EXCEPTION) << "Failed to initialize the allreduce launcher.";
56 }
57 node_ = launcher_->collective_node();
58 }
59
60 cgn_ = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
61 ClusterContext::instance()->node_base());
62
63 std::string timeout_env = common::GetEnv(kEnvNodeTimeOut);
64 if (!timeout_env.empty()) {
65 MS_LOG(INFO) << "MS_NODE_TIMEOUT env set by user: " << timeout_env;
66 retry_count_ = std::stoi(timeout_env) / 3;
67 } else {
68 retry_count_ = kMSCollectiveRetryTime / 3;
69 }
70 MS_LOG(INFO) << "Query retry count is " << retry_count_;
71
72 int random_time_lower = common::GetEnv(kEnvRetryIntervalLower).empty()
73 ? kDefaultRetryInterLower
74 : std::stoi(common::GetEnv(kEnvRetryIntervalLower));
75 int random_time_upper = common::GetEnv(kEnvRetryIntervalUpper).empty()
76 ? kDefaultRetryInterUpper
77 : std::stoi(common::GetEnv(kEnvRetryIntervalUpper));
78 MS_LOG(INFO) << "Interval of retry allgather hostname lower and upper are " << random_time_lower << " and "
79 << random_time_upper;
80 rand_distrib_ = std::uniform_int_distribution<>(random_time_lower, random_time_upper);
81
82 global_rank_id_ = global_rank;
83 global_rank_size_ = global_rank_size;
84 local_rank_id_ = local_rank_id;
85 initialized_ = true;
86 finalized_ = false;
87 return true;
88 }
89
Finalize()90 bool MsCollectiveCommLib::Finalize() {
91 if (launcher_ != nullptr) {
92 return launcher_->Finalize();
93 }
94 return true;
95 }
96
CreateCommunicationGroup(const std::string & group_name,const std::vector<uint32_t> & group_ranks,uint32_t local_group_rank,uint32_t local_group_size)97 bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
98 const std::vector<uint32_t> &group_ranks, uint32_t local_group_rank,
99 uint32_t local_group_size) {
100 if (groups_.count(group_name) != 0) {
101 MS_LOG(WARNING) << "The group " << group_name << " has already existed.";
102 return true;
103 }
104
105 MsCommunicationGroupPtr group = std::make_shared<MsCommunicationGroup>(group_name, group_ranks, global_rank_id_,
106 local_group_rank, local_group_size);
107 CHECK_IF_NULL(group);
108 groups_[group_name] = group;
109 return true;
110 }
111
AllGatherHostHashName(size_t host_hash_name,std::vector<size_t> * host_hash_names)112 bool MsCollectiveCommLib::AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) {
113 CHECK_IF_NULL(host_hash_names);
114 CHECK_IF_NULL(cgn_);
115
116 auto role = common::GetEnv(distributed::kEnvRole);
117 bool success = false;
118
119 // Retry every random time interval.
120 std::random_device rd;
121 std::mt19937 gen(rd());
122 size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
123 while (!success && --retry > 0) {
124 auto hostnames = cgn_->GetHostNames(role);
125 if (hostnames.size() < host_hash_names->size()) {
126 auto sleep_time = rand_distrib_(gen);
127 MS_LOG(WARNING) << "Retry to get hostname from the meta server node...Retry time: " << retry << "/"
128 << retry_count_ << ", sleep " << sleep_time;
129 (void)sleep(sleep_time);
130 continue;
131 } else if (hostnames.size() > host_hash_names->size()) {
132 MS_LOG(ERROR) << "Invalid number of hostnames, expected number of hostnames: " << host_hash_names->size()
133 << ", actual number of hostnames: " << hostnames.size();
134 return false;
135 }
136
137 for (size_t i = 0; i < host_hash_names->size(); i++) {
138 size_t host_hash = std::hash<std::string>()(hostnames[i]);
139 (*host_hash_names)[i] = host_hash;
140 }
141 success = true;
142 }
143 if (!success) {
144 MS_LOG(EXCEPTION) << "Failed to AllGather host's hash name due to timeout.";
145 }
146
147 return true;
148 }
149
BroadcastUniqueID(const std::string & group_name,size_t root_info_size,void * root_info)150 bool MsCollectiveCommLib::BroadcastUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) {
151 CHECK_IF_NULL(root_info);
152 CHECK_IF_NULL(cgn_);
153 auto group = GetGroup(group_name);
154 CHECK_IF_NULL(group);
155
156 uint32_t group_rank_id = group->GetGroupRank(cgn_->rank_id());
157 if (group_rank_id == 0) {
158 while (!SendUniqueID(group_name, root_info_size, root_info)) {
159 MS_LOG(WARNING) << "Send unique id to scheduler failed, retrying...";
160 if (finalized_.load()) {
161 return false;
162 }
163
164 std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
165 }
166 } else {
167 while (!QueryUniqueID(group_name, root_info_size, root_info)) {
168 MS_LOG(WARNING) << "Query unique id from scheduler failed, retrying...";
169 if (finalized_.load()) {
170 return false;
171 }
172
173 std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
174 }
175 }
176 return true;
177 }
178
SendUniqueID(const std::string & group_name,size_t root_info_size,const void * root_info) const179 bool MsCollectiveCommLib::SendUniqueID(const std::string &group_name, size_t root_info_size,
180 const void *root_info) const {
181 CHECK_IF_NULL(root_info);
182 CHECK_IF_NULL(cgn_);
183
184 // Create the group info which contains the unique id and send it to the meta server.
185 std::string node_role_prefix = cgn_->role() + "_";
186 std::string group_info_key = node_role_prefix + kGroupInfoPrefix + group_name;
187
188 bool success = false;
189 // It this is not recovery scenario, retry for 3*200s, which is 10 minutes.
190 const size_t interval = 3;
191 size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
192 while (!success && --retry > 0) {
193 success = cgn_->PutMetadata(group_info_key, root_info, root_info_size);
194 if (!success) {
195 MS_LOG(WARNING) << "Failed to send unique id for group " << group_name << ". Retry time: " << retry << "/"
196 << retry_count_;
197 (void)sleep(interval);
198 }
199 }
200 if (!success) {
201 MS_LOG(EXCEPTION) << "Failed to send unique id to the meta server node due to timeout.";
202 }
203 return true;
204 }
205
QueryUniqueID(const std::string & group_name,size_t root_info_size,void * root_info)206 bool MsCollectiveCommLib::QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) {
207 CHECK_IF_NULL(root_info);
208 CHECK_IF_NULL(cgn_);
209
210 std::string node_role_prefix = cgn_->role() + "_";
211 std::string group_info_key = node_role_prefix + kGroupInfoPrefix + group_name;
212 bool success = false;
213
214 // Retry every random time interval.
215 std::random_device rd;
216 std::mt19937 gen(rd());
217 size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
218 while (!success && --retry > 0) {
219 auto unique_id = cgn_->GetMetadata(group_info_key);
220 if (unique_id.length() > 0) {
221 auto ret = memcpy_s(root_info, root_info_size, unique_id.data(), unique_id.length());
222 if (ret != EOK) {
223 MS_LOG(WARNING) << "The memcpy_s error, errorno(" << ret << ")";
224 return false;
225 }
226 success = true;
227 } else {
228 auto sleep_time = rand_distrib_(gen);
229 MS_LOG(WARNING) << "Retry to lookup the unique id for group " << group_name
230 << " from the meta server node...Retry time: " << retry << "/" << retry_count_ << ", sleep "
231 << sleep_time;
232 (void)sleep(sleep_time);
233 }
234 }
235 if (!success) {
236 const auto &group_info = groups_.at(group_name);
237 uint32_t root_rank = group_info->group_ranks().at(0);
238 MS_LOG(EXCEPTION)
239 << "Failed to fetch the unique id of the collective lib from the meta server node. Maybe the root rank process "
240 "of this group has exited or has not executed to QueryUniqueID step. Please check root rank: "
241 << root_rank << "'s log.";
242 }
243 return true;
244 }
245
AllReduce(const void * send_buff,void * recv_buff,size_t send_count,TypeId data_type,CollectiveOpReduceType reduce_op,const std::string & group_name,void *)246 bool MsCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
247 CollectiveOpReduceType reduce_op, const std::string &group_name, void *) {
248 CHECK_IF_NULL(send_buff);
249 CHECK_IF_NULL(recv_buff);
250 CHECK_IF_NULL(launcher_);
251 if (data_type != TypeId::kNumberTypeFloat32) {
252 MS_LOG(EXCEPTION) << "AllReduce only support float32.";
253 }
254 if (reduce_op != CollectiveOpReduceType::Reduce_Sum) {
255 MS_LOG(EXCEPTION) << "AllReduce only support reduce sum.";
256 }
257 bool ret = launcher_->Execute(send_buff, recv_buff, send_count);
258 return ret;
259 }
260
AllGather(const void * send_buff,void * recv_buff,size_t send_count,TypeId data_type,const std::string &,void *)261 bool MsCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
262 const std::string &, void *) {
263 CHECK_IF_NULL(send_buff);
264 CHECK_IF_NULL(recv_buff);
265 CHECK_IF_NULL(node_);
266
267 switch (data_type) {
268 case TypeId::kNumberTypeInt8:
269 return CollectiveOpsImpl::GetInstance().AllGather<char>(send_buff, recv_buff, send_count, node_);
270 case TypeId::kNumberTypeInt32:
271 case TypeId::kNumberTypeInt:
272 return CollectiveOpsImpl::GetInstance().AllGather<int32_t>(send_buff, recv_buff, send_count, node_);
273 case TypeId::kNumberTypeUInt64:
274 return CollectiveOpsImpl::GetInstance().AllGather<uint64_t>(send_buff, recv_buff, send_count, node_);
275 case TypeId::kNumberTypeFloat32:
276 case TypeId::kNumberTypeFloat:
277 return CollectiveOpsImpl::GetInstance().AllGather<float>(send_buff, recv_buff, send_count, node_);
278 default:
279 return false;
280 }
281 }
282
Broadcast(const void * send_buff,void * recv_buff,size_t send_count,TypeId data_type,uint32_t root_rank,const std::string & group_name,void *)283 bool MsCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
284 uint32_t root_rank, const std::string &group_name, void *) {
285 CHECK_IF_NULL(send_buff);
286 CHECK_IF_NULL(recv_buff);
287 CHECK_IF_NULL(node_);
288
289 if (groups_.count(group_name) == 0) {
290 MS_LOG(ERROR) << "The group " << group_name << " does not exist.";
291 return false;
292 }
293
294 auto group = groups_[group_name];
295 CHECK_IF_NULL(group);
296 CommunicationGroupInfo group_info = {};
297 group_info.size = group->group_size();
298 group_info.global_rank = global_rank_id_;
299 group_info.group_ranks = group->group_ranks();
300 group_info.global_to_group_ranks = group->global_to_group_ranks();
301 group_info.group_to_global_ranks = group->group_to_global_ranks();
302
303 switch (data_type) {
304 case TypeId::kNumberTypeInt8:
305 return CollectiveOpsImpl::GetInstance().Broadcast<char>(send_buff, recv_buff, send_count, root_rank, node_,
306 group_info);
307 case TypeId::kNumberTypeInt32:
308 [[fallthrough]];
309 case TypeId::kNumberTypeInt:
310 return CollectiveOpsImpl::GetInstance().Broadcast<int32_t>(send_buff, recv_buff, send_count, root_rank, node_,
311 group_info);
312 case TypeId::kNumberTypeUInt64:
313 return CollectiveOpsImpl::GetInstance().Broadcast<uint64_t>(send_buff, recv_buff, send_count, root_rank, node_,
314 group_info);
315 case TypeId::kNumberTypeFloat32:
316 [[fallthrough]];
317 case TypeId::kNumberTypeFloat:
318 return CollectiveOpsImpl::GetInstance().Broadcast<float>(send_buff, recv_buff, send_count, root_rank, node_,
319 group_info);
320 default:
321 return false;
322 }
323 }
324 } // namespace cpu
325 } // namespace device
326 } // namespace mindspore
327