• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/distributed_count_service.h"
18 #include <string>
19 #include <memory>
20 #include <vector>
21 
22 namespace mindspore {
23 namespace fl {
24 namespace server {
Initialize(const std::shared_ptr<ps::core::ServerNode> & server_node,uint32_t counting_server_rank)25 void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
26                                          uint32_t counting_server_rank) {
27   MS_EXCEPTION_IF_NULL(server_node);
28   server_node_ = server_node;
29   local_rank_ = server_node_->rank_id();
30   server_num_ = ps::PSContext::instance()->initial_server_num();
31   counting_server_rank_ = counting_server_rank;
32   return;
33 }
34 
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)35 void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
36   MS_EXCEPTION_IF_NULL(communicator);
37   communicator_ = communicator;
38   communicator_->RegisterMsgCallBack(
39     "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
40   communicator_->RegisterMsgCallBack(
41     "countReachThreshold",
42     std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
43   communicator_->RegisterMsgCallBack(
44     "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
45 }
46 
RegisterCounter(const std::string & name,size_t global_threshold_count,const CounterHandlers & counter_handlers)47 void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
48                                               const CounterHandlers &counter_handlers) {
49   if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
50     MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
51     return;
52   }
53   if (global_threshold_count_.count(name) != 0) {
54     MS_LOG(INFO) << "Counter for " << name << " is already set.";
55     return;
56   }
57 
58   MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
59   // If the server is the leader server, it needs to set the counter handlers and do the real counting.
60   if (local_rank_ == counting_server_rank_) {
61     global_current_count_[name] = {};
62     global_threshold_count_[name] = global_threshold_count;
63     mutex_[name];
64   }
65   counter_handlers_[name] = counter_handlers;
66   return;
67 }
68 
ReInitCounter(const std::string & name,size_t global_threshold_count)69 bool DistributedCountService::ReInitCounter(const std::string &name, size_t global_threshold_count) {
70   MS_LOG(INFO) << "Rank " << local_rank_ << " reinitialize counter for " << name << " count:" << global_threshold_count;
71   if (local_rank_ == counting_server_rank_) {
72     std::unique_lock<std::mutex> lock(mutex_[name]);
73     if (global_threshold_count_.count(name) == 0) {
74       MS_LOG(INFO) << "Counter for " << name << " is not set.";
75       return false;
76     }
77     global_current_count_[name] = {};
78     global_threshold_count_[name] = global_threshold_count;
79   }
80   return true;
81 }
82 
Count(const std::string & name,const std::string & id,std::string * reason)83 bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
84   MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
85   if (local_rank_ == counting_server_rank_) {
86     if (global_threshold_count_.count(name) == 0) {
87       MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
88       return false;
89     }
90 
91     std::unique_lock<std::mutex> lock(mutex_[name]);
92     if (global_current_count_[name].size() >= global_threshold_count_[name]) {
93       MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
94                     << global_threshold_count_[name];
95       return false;
96     }
97 
98     MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
99     (void)global_current_count_[name].insert(id);
100     if (!TriggerCounterEvent(name, reason)) {
101       MS_LOG(ERROR) << "Leader server trigger count event failed.";
102       return false;
103     }
104   } else {
105     // If this server is a follower server, it needs to send CountRequest to the leader server.
106     CountRequest report_count_req;
107     report_count_req.set_name(name);
108     report_count_req.set_id(id);
109 
110     std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
111     if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
112                                       &report_cnt_rsp_msg)) {
113       MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
114       if (reason != nullptr) {
115         *reason = kNetworkError;
116       }
117       return false;
118     }
119 
120     MS_ERROR_IF_NULL_W_RET_VAL(report_cnt_rsp_msg, false);
121     CountResponse count_rsp;
122     (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
123     if (!count_rsp.result()) {
124       MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
125       // If the error is caused by the network issue, return the reason.
126       if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
127         *reason = kNetworkError;
128       }
129       return false;
130     }
131   }
132   return true;
133 }
134 
CountReachThreshold(const std::string & name)135 bool DistributedCountService::CountReachThreshold(const std::string &name) {
136   MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
137   if (local_rank_ == counting_server_rank_) {
138     if (global_threshold_count_.count(name) == 0) {
139       MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
140       return false;
141     }
142 
143     std::unique_lock<std::mutex> lock(mutex_[name]);
144     return global_current_count_[name].size() == global_threshold_count_[name];
145   } else {
146     CountReachThresholdRequest count_reach_threshold_req;
147     count_reach_threshold_req.set_name(name);
148 
149     std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
150     if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
151                                       ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
152       MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
153       return false;
154     }
155 
156     MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
157     CountReachThresholdResponse count_reach_threshold_rsp;
158     (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
159                                                    SizeToInt(query_cnt_enough_rsp_msg->size()));
160     return count_reach_threshold_rsp.is_enough();
161   }
162 }
163 
ResetCounter(const std::string & name)164 void DistributedCountService::ResetCounter(const std::string &name) {
165   if (local_rank_ == counting_server_rank_) {
166     MS_LOG(DEBUG) << "Leader server reset count for " << name;
167     global_current_count_[name].clear();
168   }
169   return;
170 }
171 
ReInitForScaling()172 bool DistributedCountService::ReInitForScaling() {
173   // If DistributedCountService is not initialized yet but the scaling event is triggered, do not throw exception.
174   if (server_node_ == nullptr) {
175     return true;
176   }
177 
178   MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed count service.";
179   local_rank_ = server_node_->rank_id();
180   server_num_ = IntToUint(server_node_->server_num());
181   MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
182                << server_num_;
183 
184   // Clear old counter data of this server.
185   global_current_count_.clear();
186   global_threshold_count_.clear();
187   counter_handlers_.clear();
188   return true;
189 }
190 
HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> & message)191 void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
192   MS_ERROR_IF_NULL_WO_RET_VAL(message);
193   CountRequest report_count_req;
194   (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
195   const std::string &name = report_count_req.name();
196   const std::string &id = report_count_req.id();
197 
198   CountResponse count_rsp;
199   std::unique_lock<std::mutex> lock(mutex_[name]);
200   // If leader server has no counter for the name registered, return an error.
201   if (global_threshold_count_.count(name) == 0) {
202     std::string reason = "Counter for " + name + " is not registered.";
203     count_rsp.set_result(false);
204     count_rsp.set_reason(reason);
205     MS_LOG(ERROR) << reason;
206     if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
207                                      message)) {
208       MS_LOG(ERROR) << "Sending response failed.";
209       return;
210     }
211     return;
212   }
213 
214   // If leader server already has enough count for the name, return an error.
215   if (global_current_count_[name].size() >= global_threshold_count_[name]) {
216     std::string reason =
217       "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
218     count_rsp.set_result(false);
219     count_rsp.set_reason(reason);
220     MS_LOG(ERROR) << reason;
221     if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
222                                      message)) {
223       MS_LOG(ERROR) << "Sending response failed.";
224       return;
225     }
226     return;
227   }
228 
229   // Insert the id for the counter, which means the count for the name is increased.
230   MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
231   (void)global_current_count_[name].insert(id);
232   std::string reason = "success";
233   if (!TriggerCounterEvent(name, &reason)) {
234     count_rsp.set_result(false);
235     count_rsp.set_reason(reason);
236   } else {
237     count_rsp.set_result(true);
238     count_rsp.set_reason(reason);
239   }
240   if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
241                                    message)) {
242     MS_LOG(ERROR) << "Sending response failed.";
243     return;
244   }
245   return;
246 }
247 
HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> & message)248 void DistributedCountService::HandleCountReachThresholdRequest(
249   const std::shared_ptr<ps::core::MessageHandler> &message) {
250   MS_ERROR_IF_NULL_WO_RET_VAL(message);
251   CountReachThresholdRequest count_reach_threshold_req;
252   (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
253   const std::string &name = count_reach_threshold_req.name();
254 
255   std::unique_lock<std::mutex> lock(mutex_[name]);
256   if (global_threshold_count_.count(name) == 0) {
257     MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
258     return;
259   }
260 
261   CountReachThresholdResponse count_reach_threshold_rsp;
262   count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
263   if (!communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
264                                    count_reach_threshold_rsp.SerializeAsString().size(), message)) {
265     MS_LOG(ERROR) << "Sending response failed.";
266     return;
267   }
268   return;
269 }
270 
HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> & message)271 void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
272   MS_ERROR_IF_NULL_WO_RET_VAL(message);
273   // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
274   // callbacks.
275   std::string couter_event_rsp_msg = "success";
276   if (!communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message)) {
277     MS_LOG(ERROR) << "Sending response failed.";
278     return;
279   }
280 
281   CounterEvent counter_event;
282   (void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
283   const auto &type = counter_event.type();
284   const auto &name = counter_event.name();
285 
286   if (counter_handlers_.count(name) == 0) {
287     MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
288     return;
289   }
290   MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
291   if (type == CounterEventType::FIRST_CNT) {
292     counter_handlers_[name].first_count_handler(message);
293   } else if (type == CounterEventType::LAST_CNT) {
294     counter_handlers_[name].last_count_handler(message);
295   } else {
296     MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
297     return;
298   }
299   return;
300 }
301 
TriggerCounterEvent(const std::string & name,std::string * reason)302 bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
303   if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
304     MS_LOG(ERROR) << "The counter of " << name << " is not registered.";
305     return false;
306   }
307 
308   MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
309                << ", threshold count is " << global_threshold_count_[name];
310   // The threshold count may be 1 so the first and last count event should be both activated.
311   if (global_current_count_[name].size() == 1) {
312     if (!TriggerFirstCountEvent(name, reason)) {
313       return false;
314     }
315   }
316   if (global_current_count_[name].size() == global_threshold_count_[name]) {
317     if (!TriggerLastCountEvent(name, reason)) {
318       return false;
319     }
320   }
321   return true;
322 }
323 
TriggerFirstCountEvent(const std::string & name,std::string * reason)324 bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) {
325   MS_LOG(DEBUG) << "Activating first count event for " << name;
326   CounterEvent first_count_event;
327   first_count_event.set_type(CounterEventType::FIRST_CNT);
328   first_count_event.set_name(name);
329 
330   // Broadcast to all follower servers.
331   for (uint32_t i = 1; i < server_num_; i++) {
332     if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
333       MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
334       if (reason != nullptr) {
335         *reason = kNetworkError;
336       }
337       return false;
338     }
339   }
340 
341   if (counter_handlers_.count(name) == 0) {
342     MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
343     return false;
344   }
345   // Leader server directly calls the callback.
346   counter_handlers_[name].first_count_handler(nullptr);
347   return true;
348 }
349 
TriggerLastCountEvent(const std::string & name,std::string * reason)350 bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) {
351   MS_LOG(INFO) << "Activating last count event for " << name;
352   CounterEvent last_count_event;
353   last_count_event.set_type(CounterEventType::LAST_CNT);
354   last_count_event.set_name(name);
355 
356   // Broadcast to all follower servers.
357   for (uint32_t i = 1; i < server_num_; i++) {
358     if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
359       MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
360       if (reason != nullptr) {
361         *reason = kNetworkError;
362       }
363       return false;
364     }
365   }
366 
367   if (counter_handlers_.count(name) == 0) {
368     MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
369     return false;
370   }
371   // Leader server directly calls the callback.
372   counter_handlers_[name].last_count_handler(nullptr);
373   return true;
374 }
375 }  // namespace server
376 }  // namespace fl
377 }  // namespace mindspore
378