1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/distributed_count_service.h"
18 #include <string>
19 #include <memory>
20 #include <vector>
21
22 namespace mindspore {
23 namespace fl {
24 namespace server {
Initialize(const std::shared_ptr<ps::core::ServerNode> & server_node,uint32_t counting_server_rank)25 void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
26 uint32_t counting_server_rank) {
27 MS_EXCEPTION_IF_NULL(server_node);
28 server_node_ = server_node;
29 local_rank_ = server_node_->rank_id();
30 server_num_ = ps::PSContext::instance()->initial_server_num();
31 counting_server_rank_ = counting_server_rank;
32 return;
33 }
34
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)35 void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
36 MS_EXCEPTION_IF_NULL(communicator);
37 communicator_ = communicator;
38 communicator_->RegisterMsgCallBack(
39 "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
40 communicator_->RegisterMsgCallBack(
41 "countReachThreshold",
42 std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
43 communicator_->RegisterMsgCallBack(
44 "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
45 }
46
RegisterCounter(const std::string & name,size_t global_threshold_count,const CounterHandlers & counter_handlers)47 void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
48 const CounterHandlers &counter_handlers) {
49 if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {
50 MS_LOG(EXCEPTION) << "First count handler or last count handler is not set.";
51 return;
52 }
53 if (global_threshold_count_.count(name) != 0) {
54 MS_LOG(INFO) << "Counter for " << name << " is already set.";
55 return;
56 }
57
58 MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count;
59 // If the server is the leader server, it needs to set the counter handlers and do the real counting.
60 if (local_rank_ == counting_server_rank_) {
61 global_current_count_[name] = {};
62 global_threshold_count_[name] = global_threshold_count;
63 mutex_[name];
64 }
65 counter_handlers_[name] = counter_handlers;
66 return;
67 }
68
ReInitCounter(const std::string & name,size_t global_threshold_count)69 bool DistributedCountService::ReInitCounter(const std::string &name, size_t global_threshold_count) {
70 MS_LOG(INFO) << "Rank " << local_rank_ << " reinitialize counter for " << name << " count:" << global_threshold_count;
71 if (local_rank_ == counting_server_rank_) {
72 std::unique_lock<std::mutex> lock(mutex_[name]);
73 if (global_threshold_count_.count(name) == 0) {
74 MS_LOG(INFO) << "Counter for " << name << " is not set.";
75 return false;
76 }
77 global_current_count_[name] = {};
78 global_threshold_count_[name] = global_threshold_count;
79 }
80 return true;
81 }
82
Count(const std::string & name,const std::string & id,std::string * reason)83 bool DistributedCountService::Count(const std::string &name, const std::string &id, std::string *reason) {
84 MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id;
85 if (local_rank_ == counting_server_rank_) {
86 if (global_threshold_count_.count(name) == 0) {
87 MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
88 return false;
89 }
90
91 std::unique_lock<std::mutex> lock(mutex_[name]);
92 if (global_current_count_[name].size() >= global_threshold_count_[name]) {
93 MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is "
94 << global_threshold_count_[name];
95 return false;
96 }
97
98 MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
99 (void)global_current_count_[name].insert(id);
100 if (!TriggerCounterEvent(name, reason)) {
101 MS_LOG(ERROR) << "Leader server trigger count event failed.";
102 return false;
103 }
104 } else {
105 // If this server is a follower server, it needs to send CountRequest to the leader server.
106 CountRequest report_count_req;
107 report_count_req.set_name(name);
108 report_count_req.set_id(id);
109
110 std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
111 if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
112 &report_cnt_rsp_msg)) {
113 MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
114 if (reason != nullptr) {
115 *reason = kNetworkError;
116 }
117 return false;
118 }
119
120 MS_ERROR_IF_NULL_W_RET_VAL(report_cnt_rsp_msg, false);
121 CountResponse count_rsp;
122 (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
123 if (!count_rsp.result()) {
124 MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
125 // If the error is caused by the network issue, return the reason.
126 if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
127 *reason = kNetworkError;
128 }
129 return false;
130 }
131 }
132 return true;
133 }
134
CountReachThreshold(const std::string & name)135 bool DistributedCountService::CountReachThreshold(const std::string &name) {
136 MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name;
137 if (local_rank_ == counting_server_rank_) {
138 if (global_threshold_count_.count(name) == 0) {
139 MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
140 return false;
141 }
142
143 std::unique_lock<std::mutex> lock(mutex_[name]);
144 return global_current_count_[name].size() == global_threshold_count_[name];
145 } else {
146 CountReachThresholdRequest count_reach_threshold_req;
147 count_reach_threshold_req.set_name(name);
148
149 std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
150 if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
151 ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
152 MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
153 return false;
154 }
155
156 MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);
157 CountReachThresholdResponse count_reach_threshold_rsp;
158 (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
159 SizeToInt(query_cnt_enough_rsp_msg->size()));
160 return count_reach_threshold_rsp.is_enough();
161 }
162 }
163
ResetCounter(const std::string & name)164 void DistributedCountService::ResetCounter(const std::string &name) {
165 if (local_rank_ == counting_server_rank_) {
166 MS_LOG(DEBUG) << "Leader server reset count for " << name;
167 global_current_count_[name].clear();
168 }
169 return;
170 }
171
ReInitForScaling()172 bool DistributedCountService::ReInitForScaling() {
173 // If DistributedCountService is not initialized yet but the scaling event is triggered, do not throw exception.
174 if (server_node_ == nullptr) {
175 return true;
176 }
177
178 MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed count service.";
179 local_rank_ = server_node_->rank_id();
180 server_num_ = IntToUint(server_node_->server_num());
181 MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
182 << server_num_;
183
184 // Clear old counter data of this server.
185 global_current_count_.clear();
186 global_threshold_count_.clear();
187 counter_handlers_.clear();
188 return true;
189 }
190
HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> & message)191 void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
192 MS_ERROR_IF_NULL_WO_RET_VAL(message);
193 CountRequest report_count_req;
194 (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
195 const std::string &name = report_count_req.name();
196 const std::string &id = report_count_req.id();
197
198 CountResponse count_rsp;
199 std::unique_lock<std::mutex> lock(mutex_[name]);
200 // If leader server has no counter for the name registered, return an error.
201 if (global_threshold_count_.count(name) == 0) {
202 std::string reason = "Counter for " + name + " is not registered.";
203 count_rsp.set_result(false);
204 count_rsp.set_reason(reason);
205 MS_LOG(ERROR) << reason;
206 if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
207 message)) {
208 MS_LOG(ERROR) << "Sending response failed.";
209 return;
210 }
211 return;
212 }
213
214 // If leader server already has enough count for the name, return an error.
215 if (global_current_count_[name].size() >= global_threshold_count_[name]) {
216 std::string reason =
217 "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]);
218 count_rsp.set_result(false);
219 count_rsp.set_reason(reason);
220 MS_LOG(ERROR) << reason;
221 if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
222 message)) {
223 MS_LOG(ERROR) << "Sending response failed.";
224 return;
225 }
226 return;
227 }
228
229 // Insert the id for the counter, which means the count for the name is increased.
230 MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id;
231 (void)global_current_count_[name].insert(id);
232 std::string reason = "success";
233 if (!TriggerCounterEvent(name, &reason)) {
234 count_rsp.set_result(false);
235 count_rsp.set_reason(reason);
236 } else {
237 count_rsp.set_result(true);
238 count_rsp.set_reason(reason);
239 }
240 if (!communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(),
241 message)) {
242 MS_LOG(ERROR) << "Sending response failed.";
243 return;
244 }
245 return;
246 }
247
HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> & message)248 void DistributedCountService::HandleCountReachThresholdRequest(
249 const std::shared_ptr<ps::core::MessageHandler> &message) {
250 MS_ERROR_IF_NULL_WO_RET_VAL(message);
251 CountReachThresholdRequest count_reach_threshold_req;
252 (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
253 const std::string &name = count_reach_threshold_req.name();
254
255 std::unique_lock<std::mutex> lock(mutex_[name]);
256 if (global_threshold_count_.count(name) == 0) {
257 MS_LOG(ERROR) << "Counter for " << name << " is not registered.";
258 return;
259 }
260
261 CountReachThresholdResponse count_reach_threshold_rsp;
262 count_reach_threshold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]);
263 if (!communicator_->SendResponse(count_reach_threshold_rsp.SerializeAsString().data(),
264 count_reach_threshold_rsp.SerializeAsString().size(), message)) {
265 MS_LOG(ERROR) << "Sending response failed.";
266 return;
267 }
268 return;
269 }
270
HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> & message)271 void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
272 MS_ERROR_IF_NULL_WO_RET_VAL(message);
273 // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the
274 // callbacks.
275 std::string couter_event_rsp_msg = "success";
276 if (!communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message)) {
277 MS_LOG(ERROR) << "Sending response failed.";
278 return;
279 }
280
281 CounterEvent counter_event;
282 (void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
283 const auto &type = counter_event.type();
284 const auto &name = counter_event.name();
285
286 if (counter_handlers_.count(name) == 0) {
287 MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
288 return;
289 }
290 MS_LOG(DEBUG) << "Rank " << local_rank_ << " do counter event " << type << " for " << name;
291 if (type == CounterEventType::FIRST_CNT) {
292 counter_handlers_[name].first_count_handler(message);
293 } else if (type == CounterEventType::LAST_CNT) {
294 counter_handlers_[name].last_count_handler(message);
295 } else {
296 MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid.";
297 return;
298 }
299 return;
300 }
301
TriggerCounterEvent(const std::string & name,std::string * reason)302 bool DistributedCountService::TriggerCounterEvent(const std::string &name, std::string *reason) {
303 if (global_current_count_.count(name) == 0 || global_threshold_count_.count(name) == 0) {
304 MS_LOG(ERROR) << "The counter of " << name << " is not registered.";
305 return false;
306 }
307
308 MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size()
309 << ", threshold count is " << global_threshold_count_[name];
310 // The threshold count may be 1 so the first and last count event should be both activated.
311 if (global_current_count_[name].size() == 1) {
312 if (!TriggerFirstCountEvent(name, reason)) {
313 return false;
314 }
315 }
316 if (global_current_count_[name].size() == global_threshold_count_[name]) {
317 if (!TriggerLastCountEvent(name, reason)) {
318 return false;
319 }
320 }
321 return true;
322 }
323
TriggerFirstCountEvent(const std::string & name,std::string * reason)324 bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, std::string *reason) {
325 MS_LOG(DEBUG) << "Activating first count event for " << name;
326 CounterEvent first_count_event;
327 first_count_event.set_type(CounterEventType::FIRST_CNT);
328 first_count_event.set_name(name);
329
330 // Broadcast to all follower servers.
331 for (uint32_t i = 1; i < server_num_; i++) {
332 if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
333 MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
334 if (reason != nullptr) {
335 *reason = kNetworkError;
336 }
337 return false;
338 }
339 }
340
341 if (counter_handlers_.count(name) == 0) {
342 MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
343 return false;
344 }
345 // Leader server directly calls the callback.
346 counter_handlers_[name].first_count_handler(nullptr);
347 return true;
348 }
349
TriggerLastCountEvent(const std::string & name,std::string * reason)350 bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std::string *reason) {
351 MS_LOG(INFO) << "Activating last count event for " << name;
352 CounterEvent last_count_event;
353 last_count_event.set_type(CounterEventType::LAST_CNT);
354 last_count_event.set_name(name);
355
356 // Broadcast to all follower servers.
357 for (uint32_t i = 1; i < server_num_; i++) {
358 if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
359 MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
360 if (reason != nullptr) {
361 *reason = kNetworkError;
362 }
363 return false;
364 }
365 }
366
367 if (counter_handlers_.count(name) == 0) {
368 MS_LOG(ERROR) << "The counter handler of " << name << " is not registered.";
369 return false;
370 }
371 // Leader server directly calls the callback.
372 counter_handlers_[name].last_count_handler(nullptr);
373 return true;
374 }
375 } // namespace server
376 } // namespace fl
377 } // namespace mindspore
378