/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/hal/hardware/ms_collective_comm_lib.h"
#include "utils/ms_context.h"
#include "include/backend/distributed/constants.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "runtime/collective/collective_communication_lib.h"
#include "plugin/device/cpu/hal/hardware/allreduce_impl.h"

namespace mindspore {
namespace device {
namespace cpu {
using distributed::cluster::topology::kDefaultRetryInterLower;
using distributed::cluster::topology::kDefaultRetryInterUpper;
using distributed::cluster::topology::kEnvNodeTimeOut;
using distributed::cluster::topology::kEnvRetryIntervalLower;
using distributed::cluster::topology::kEnvRetryIntervalUpper;
using distributed::recovery::RecoveryContext;

// These keywords are used for synchronization of the collective communication metadata (e.g. the unique id).
constexpr char kGroupInfoPrefix[] = "group_info_";
constexpr char kGroupName[] = "group_name";
constexpr char kUniqueId[] = "unique_id";
MsCollectiveCommLib::MsCollectiveCommLib() {
  // Generate the global group name with node role.
  global_group_name_ = kMCCLGlobalGroupName;
  MS_LOG(INFO) << "Global group name of MindSpore collective communication library is " << global_group_name_;
}

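// Initializes the library: creates the AllReduceLauncher for the CPU backend, binds the compute graph node, and
// reads the retry count and the retry interval bounds from environment variables.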
bool MsCollectiveCommLib::Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) {
  if (initialized_) {
    MS_LOG(WARNING) << "MsCollectiveCommLib has already been initialized.";
    return true;
  }

  // Only use AllReduceLauncher when the backend is CPU.
  if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
    launcher_ = std::make_unique<AllReduceLauncher>();
    CHECK_IF_NULL(launcher_);
    if (!launcher_->Initialize()) {
      MS_LOG(EXCEPTION) << "Failed to initialize the allreduce launcher.";
    }
    node_ = launcher_->collective_node();
  }

  cgn_ = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
    ClusterContext::instance()->node_base());

  std::string timeout_env = common::GetEnv(kEnvNodeTimeOut);
  if (!timeout_env.empty()) {
    MS_LOG(INFO) << "MS_NODE_TIMEOUT env set by user: " << timeout_env;
    retry_count_ = std::stoi(timeout_env) / 3;
  } else {
    retry_count_ = kMSCollectiveRetryTime / 3;
  }
  MS_LOG(INFO) << "Query retry count is " << retry_count_;

  int random_time_lower = common::GetEnv(kEnvRetryIntervalLower).empty()
                            ? kDefaultRetryInterLower
                            : std::stoi(common::GetEnv(kEnvRetryIntervalLower));
  int random_time_upper = common::GetEnv(kEnvRetryIntervalUpper).empty()
                            ? kDefaultRetryInterUpper
                            : std::stoi(common::GetEnv(kEnvRetryIntervalUpper));
  MS_LOG(INFO) << "Lower and upper bounds of the retry interval for allgathering hostnames are " << random_time_lower
               << " and " << random_time_upper;
  rand_distrib_ = std::uniform_int_distribution<>(random_time_lower, random_time_upper);

  global_rank_id_ = global_rank;
  global_rank_size_ = global_rank_size;
  local_rank_id_ = local_rank_id;
  initialized_ = true;
  finalized_ = false;
  return true;
}

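// Finalizes the allreduce launcher if one was created; otherwise there is nothing to release.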
bool MsCollectiveCommLib::Finalize() {
  if (launcher_ != nullptr) {
    return launcher_->Finalize();
  }
  return true;
}

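// Creates a communication group with the given ranks and caches it by name; creating an existing group is a no-op.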
bool MsCollectiveCommLib::CreateCommunicationGroup(const std::string &group_name,
                                                   const std::vector<uint32_t> &group_ranks, uint32_t local_group_rank,
                                                   uint32_t local_group_size) {
  if (groups_.count(group_name) != 0) {
    MS_LOG(WARNING) << "The group " << group_name << " already exists.";
    return true;
  }

  MsCommunicationGroupPtr group = std::make_shared<MsCommunicationGroup>(group_name, group_ranks, global_rank_id_,
                                                                         local_group_rank, local_group_size);
  CHECK_IF_NULL(group);
  groups_[group_name] = group;
  return true;
}

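// Gathers the host hash names of all processes with the same role from the meta server node, retrying at a random
// interval until the expected number of hostnames has been collected or the retry budget is exhausted.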
bool MsCollectiveCommLib::AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) {
  CHECK_IF_NULL(host_hash_names);
  CHECK_IF_NULL(cgn_);

  auto role = common::GetEnv(distributed::kEnvRole);
  bool success = false;

  // Retry every random time interval.
  std::random_device rd;
  std::mt19937 gen(rd());
  size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
  while (!success && --retry > 0) {
    auto hostnames = cgn_->GetHostNames(role);
    if (hostnames.size() < host_hash_names->size()) {
      auto sleep_time = rand_distrib_(gen);
      MS_LOG(WARNING) << "Retrying to get hostnames from the meta server node... Retry count: " << retry << "/"
                      << retry_count_ << ", sleep " << sleep_time;
      (void)sleep(sleep_time);
      continue;
    } else if (hostnames.size() > host_hash_names->size()) {
      MS_LOG(ERROR) << "Invalid number of hostnames, expected number of hostnames: " << host_hash_names->size()
                    << ", actual number of hostnames: " << hostnames.size();
      return false;
    }

    for (size_t i = 0; i < host_hash_names->size(); i++) {
      size_t host_hash = std::hash<std::string>()(hostnames[i]);
      (*host_hash_names)[i] = host_hash;
    }
    success = true;
  }
  if (!success) {
    MS_LOG(EXCEPTION) << "Failed to AllGather host hash names due to timeout.";
  }

  return true;
}

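// Broadcasts the unique id of a group: the group's rank 0 publishes it to the meta server while the other ranks poll
// until they can fetch it.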
bool MsCollectiveCommLib::BroadcastUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) {
  CHECK_IF_NULL(root_info);
  CHECK_IF_NULL(cgn_);
  auto group = GetGroup(group_name);
  CHECK_IF_NULL(group);

  uint32_t group_rank_id = group->GetGroupRank(cgn_->rank_id());
  if (group_rank_id == 0) {
    while (!SendUniqueID(group_name, root_info_size, root_info)) {
      MS_LOG(WARNING) << "Send unique id to scheduler failed, retrying...";
      if (finalized_.load()) {
        return false;
      }

      std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
    }
  } else {
    while (!QueryUniqueID(group_name, root_info_size, root_info)) {
      MS_LOG(WARNING) << "Query unique id from scheduler failed, retrying...";
      if (finalized_.load()) {
        return false;
      }

      std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
    }
  }
  return true;
}

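// Puts the unique id of the group into the meta server under a role-prefixed key so that the other ranks of the
// group can query it.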
bool MsCollectiveCommLib::SendUniqueID(const std::string &group_name, size_t root_info_size,
                                       const void *root_info) const {
  CHECK_IF_NULL(root_info);
  CHECK_IF_NULL(cgn_);

  // Create the group info which contains the unique id and send it to the meta server.
  std::string node_role_prefix = cgn_->role() + "_";
  std::string group_info_key = node_role_prefix + kGroupInfoPrefix + group_name;

  bool success = false;
  // If this is not a recovery scenario, retry for 3 * 200s, which is 10 minutes.
  const size_t interval = 3;
  size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
  while (!success && --retry > 0) {
    success = cgn_->PutMetadata(group_info_key, root_info, root_info_size);
    if (!success) {
      MS_LOG(WARNING) << "Failed to send unique id for group " << group_name << ". Retry count: " << retry << "/"
                      << retry_count_;
      (void)sleep(interval);
    }
  }
  if (!success) {
    MS_LOG(EXCEPTION) << "Failed to send unique id to the meta server node due to timeout.";
  }
  return true;
}

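// Fetches the unique id of the group from the meta server, retrying at a random interval until the root rank of the
// group has published it or the retry budget is exhausted.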
bool MsCollectiveCommLib::QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) {
  CHECK_IF_NULL(root_info);
  CHECK_IF_NULL(cgn_);

  std::string node_role_prefix = cgn_->role() + "_";
  std::string group_info_key = node_role_prefix + kGroupInfoPrefix + group_name;
  bool success = false;

  // Retry every random time interval.
  std::random_device rd;
  std::mt19937 gen(rd());
  size_t retry = RecoveryContext::GetInstance()->enable_recovery() ? SIZE_MAX : retry_count_;
  while (!success && --retry > 0) {
    auto unique_id = cgn_->GetMetadata(group_info_key);
    if (unique_id.length() > 0) {
      auto ret = memcpy_s(root_info, root_info_size, unique_id.data(), unique_id.length());
      if (ret != EOK) {
        MS_LOG(WARNING) << "memcpy_s failed, error number: " << ret;
        return false;
      }
      success = true;
    } else {
      auto sleep_time = rand_distrib_(gen);
      MS_LOG(WARNING) << "Retrying to look up the unique id for group " << group_name
                      << " from the meta server node... Retry count: " << retry << "/" << retry_count_ << ", sleep "
                      << sleep_time;
      (void)sleep(sleep_time);
    }
  }
  if (!success) {
    const auto &group_info = groups_.at(group_name);
    uint32_t root_rank = group_info->group_ranks().at(0);
    MS_LOG(EXCEPTION)
      << "Failed to fetch the unique id of the collective lib from the meta server node. Maybe the root rank process "
         "of this group has exited or has not reached the QueryUniqueID step. Please check root rank "
      << root_rank << "'s log.";
  }
  return true;
}

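// AllReduce for the CPU backend via the AllReduceLauncher; only float32 data and the sum reduction are supported.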
bool MsCollectiveCommLib::AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                                    CollectiveOpReduceType reduce_op, const std::string &group_name, void *) {
  CHECK_IF_NULL(send_buff);
  CHECK_IF_NULL(recv_buff);
  CHECK_IF_NULL(launcher_);
  if (data_type != TypeId::kNumberTypeFloat32) {
    MS_LOG(EXCEPTION) << "AllReduce only supports float32.";
  }
  if (reduce_op != CollectiveOpReduceType::Reduce_Sum) {
    MS_LOG(EXCEPTION) << "AllReduce only supports reduce sum.";
  }
  bool ret = launcher_->Execute(send_buff, recv_buff, send_count);
  return ret;
}

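// AllGather dispatched by data type to the templated CollectiveOpsImpl implementation; unsupported types return false.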
bool MsCollectiveCommLib::AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                                    const std::string &, void *) {
  CHECK_IF_NULL(send_buff);
  CHECK_IF_NULL(recv_buff);
  CHECK_IF_NULL(node_);

  switch (data_type) {
    case TypeId::kNumberTypeInt8:
      return CollectiveOpsImpl::GetInstance().AllGather<char>(send_buff, recv_buff, send_count, node_);
    case TypeId::kNumberTypeInt32:
    case TypeId::kNumberTypeInt:
      return CollectiveOpsImpl::GetInstance().AllGather<int32_t>(send_buff, recv_buff, send_count, node_);
    case TypeId::kNumberTypeUInt64:
      return CollectiveOpsImpl::GetInstance().AllGather<uint64_t>(send_buff, recv_buff, send_count, node_);
    case TypeId::kNumberTypeFloat32:
    case TypeId::kNumberTypeFloat:
      return CollectiveOpsImpl::GetInstance().AllGather<float>(send_buff, recv_buff, send_count, node_);
    default:
      return false;
  }
}

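// Broadcast dispatched by data type to the templated CollectiveOpsImpl implementation; the group must have been
// created beforehand, and unsupported types return false.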
bool MsCollectiveCommLib::Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                                    uint32_t root_rank, const std::string &group_name, void *) {
  CHECK_IF_NULL(send_buff);
  CHECK_IF_NULL(recv_buff);
  CHECK_IF_NULL(node_);

  if (groups_.count(group_name) == 0) {
    MS_LOG(ERROR) << "The group " << group_name << " does not exist.";
    return false;
  }

  auto group = groups_[group_name];
  CHECK_IF_NULL(group);
  CommunicationGroupInfo group_info = {};
  group_info.size = group->group_size();
  group_info.global_rank = global_rank_id_;
  group_info.group_ranks = group->group_ranks();
  group_info.global_to_group_ranks = group->global_to_group_ranks();
  group_info.group_to_global_ranks = group->group_to_global_ranks();

  switch (data_type) {
    case TypeId::kNumberTypeInt8:
      return CollectiveOpsImpl::GetInstance().Broadcast<char>(send_buff, recv_buff, send_count, root_rank, node_,
                                                              group_info);
    case TypeId::kNumberTypeInt32:
      [[fallthrough]];
    case TypeId::kNumberTypeInt:
      return CollectiveOpsImpl::GetInstance().Broadcast<int32_t>(send_buff, recv_buff, send_count, root_rank, node_,
                                                                 group_info);
    case TypeId::kNumberTypeUInt64:
      return CollectiveOpsImpl::GetInstance().Broadcast<uint64_t>(send_buff, recv_buff, send_count, root_rank, node_,
                                                                  group_info);
    case TypeId::kNumberTypeFloat32:
      [[fallthrough]];
    case TypeId::kNumberTypeFloat:
      return CollectiveOpsImpl::GetInstance().Broadcast<float>(send_buff, recv_buff, send_count, root_rank, node_,
                                                               group_info);
    default:
      return false;
  }
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore