/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fl/server/iteration.h"
#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include "fl/server/model_store.h"
#include "fl/server/server.h"

namespace mindspore {
namespace fl {
namespace server {
class Server;

Iteration::~Iteration() {
  move_to_next_thread_running_ = false;
  next_iteration_cv_.notify_all();
  if (move_to_next_thread_.joinable()) {
    move_to_next_thread_.join();
  }
}

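// Register the TCP message handlers for the cross-server iteration protocol: iteration-number synchronization,
// the follower-to-leader move-to-next notification, and the leader's prepare/proceed/end broadcasts.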
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
  MS_EXCEPTION_IF_NULL(communicator);
  communicator_ = communicator;
  communicator_->RegisterMsgCallBack("syncIteration",
                                     std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "notifyLeaderToNextIter",
    std::bind(&Iteration::HandleNotifyLeaderMoveToNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "prepareForNextIter", std::bind(&Iteration::HandlePrepareForNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack("proceedToNextIter",
                                     std::bind(&Iteration::HandleMoveToNextIterRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack("endLastIter",
                                     std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}

void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
  MS_EXCEPTION_IF_NULL(server_node);
  server_node_ = server_node;
  server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
                                           std::bind(&Iteration::HandleIterationRunningEvent, this));
  server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
                                           std::bind(&Iteration::HandleIterationCompletedEvent, this));
}

void Iteration::AddRound(const std::shared_ptr<Round> &round) {
  MS_EXCEPTION_IF_NULL(round);
  rounds_.push_back(round);
}

void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
                           const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
  if (communicators.empty()) {
    MS_LOG(EXCEPTION) << "Communicators for rounds are empty.";
    return;
  }

  (void)std::for_each(communicators.begin(), communicators.end(),
                      [&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
                        for (auto &round : rounds_) {
                          MS_EXCEPTION_IF_NULL(round);
                          round->Initialize(communicator, timeout_cb, finish_iteration_cb);
                        }
                      });

  // The time window for one iteration, which will be used in some round kernels.
  size_t iteration_time_window =
    std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0), [](size_t total, const std::shared_ptr<Round> &round) {
      MS_EXCEPTION_IF_NULL(round);
      return round->check_timeout() ? total + round->time_window() : total;
    });
  LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
  MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;

  // Initialize the thread which will handle the signal from round kernels.
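  // NotifyNext stores the result flags and signals next_iteration_cv_; this thread wakes up, re-checks the
  // running flag so the destructor can shut it down, and then drives the cluster to the next iteration.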
  move_to_next_thread_ = std::thread([this]() {
    while (move_to_next_thread_running_.load()) {
      std::unique_lock<std::mutex> lock(next_iteration_mutex_);
      next_iteration_cv_.wait(lock);
      if (!move_to_next_thread_running_.load()) {
        break;
      }
      MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
    }
  });
}

void Iteration::ClearRounds() { rounds_.clear(); }

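// Called by round kernels when the current iteration should end. The heavy lifting is delegated to
// move_to_next_thread_ so that the caller is not blocked by the cluster-wide synchronization.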
void Iteration::NotifyNext(bool is_last_iter_valid, const std::string &reason) {
  std::unique_lock<std::mutex> lock(next_iteration_mutex_);
  is_last_iteration_valid_ = is_last_iter_valid;
  move_to_next_reason_ = reason;
  next_iteration_cv_.notify_one();
}

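// Drives the cluster to the next iteration. The leader broadcasts three phases to the followers: prepare
// (enter safe mode), proceed (store the model and reset rounds), and end (finalize and leave safe mode).
// A follower instead forwards the request to the leader, which then runs the same three phases.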
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
  MS_LOG(INFO) << "Notify cluster to proceed to the next iteration. Current iteration is " << iteration_num_
               << ", last iteration valid: " << is_last_iter_valid << ". Reason: " << reason;
  if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
    return;
  }

  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Broadcasting the prepare-for-next-iteration request failed.";
      return;
    }
    if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Broadcasting the proceed-to-next-iteration request failed.";
      return;
    }
    if (!BroadcastEndLastIterRequest(iteration_num_)) {
      MS_LOG(ERROR) << "Broadcasting the end-last-iteration request failed.";
      return;
    }
  } else {
    // If this server is a follower, notify the leader server to drive the cluster to the next iteration.
    if (!NotifyLeaderMoveToNextIteration(is_last_iter_valid, reason)) {
      MS_LOG(ERROR) << "Server " << server_node_->rank_id() << " failed to notify the leader server.";
      return;
    }
  }
}

void Iteration::SetIterationRunning() {
  MS_LOG(INFO) << "Iteration " << iteration_num_ << " starts running.";
  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    // This event helps workers/servers stay consistent in iteration state.
    server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
  }

  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  iteration_state_ = IterationState::kRunning;
  start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
}

void Iteration::SetIterationCompleted() {
  MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
  MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
  if (server_node_->rank_id() == kLeaderServerRank) {
    // This event helps workers/servers stay consistent in iteration state.
    server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
  }

  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  iteration_state_ = IterationState::kCompleted;
  complete_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
}

void Iteration::ScalingBarrier() {
  MS_LOG(INFO) << "Starting Iteration scaling barrier.";
  std::unique_lock<std::mutex> lock(iteration_state_mtx_);
  if (iteration_state_.load() != IterationState::kCompleted) {
    iteration_state_cv_.wait(lock);
  }
  MS_LOG(INFO) << "Ending Iteration scaling barrier.";
}

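// Reinitializes every round for the new cluster size. A follower also synchronizes its iteration number from
// the leader so the whole cluster agrees on the current iteration after scaling.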
bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
  for (auto &round : rounds_) {
    if (!round->ReInitForScaling(server_num)) {
      MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
      return false;
    }
  }
  if (server_rank != kLeaderServerRank) {
    if (!SyncIteration(server_rank)) {
      MS_LOG(ERROR) << "Synchronizing iteration failed.";
      return false;
    }
  }
  return true;
}

bool Iteration::ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &updated_rounds_config) {
  for (const auto &updated_round : updated_rounds_config) {
    for (const auto &round : rounds_) {
      if (updated_round.name == round->name()) {
        MS_LOG(INFO) << "Reinitialize for round " << round->name();
        if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window)) {
          MS_LOG(ERROR) << "Reinitializing for round " << round->name() << " failed.";
          return false;
        }
      }
    }
  }
  return true;
}

const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return rounds_; }

bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }

void Iteration::set_metrics(const std::shared_ptr<IterationMetrics> &metrics) { metrics_ = metrics; }

void Iteration::set_loss(float loss) { loss_ = loss; }

void Iteration::set_accuracy(float accuracy) { accuracy_ = accuracy; }

InstanceState Iteration::instance_state() const { return instance_state_.load(); }

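// EnableServerInstance, DisableServerInstance and NewInstance share the same guard: is_instance_being_updated_
// is flipped under instance_mtx_ so that only one instance-level operation can be in flight at a time.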
bool Iteration::EnableServerInstance(std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before enabling the server instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry enabling the server later.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kFinish) {
    *result = "The instance is completed, so the server cannot be enabled now.";
    MS_LOG(WARNING) << *result;
    return false;
  }

  // Start enabling the server instance.
  is_instance_being_updated_ = true;

  instance_state_ = InstanceState::kRunning;
  *result = "Enabling FL-Server succeeded.";
  MS_LOG(INFO) << *result;

  // End enabling the server instance.
  is_instance_being_updated_ = false;
  return true;
}

bool Iteration::DisableServerInstance(std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before disabling the server instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry disabling the server later.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kFinish) {
    *result = "The instance is completed, so the server cannot be disabled now.";
    MS_LOG(WARNING) << *result;
    return false;
  }
  if (instance_state_.load() == InstanceState::kDisable) {
    *result = "Disabling FL-Server succeeded.";
    MS_LOG(INFO) << *result;
    return true;
  }

  // Start disabling the server instance.
  is_instance_being_updated_ = true;

  // If the instance is running, drop the current iteration and move to the next.
  instance_state_ = InstanceState::kDisable;
  if (!ForciblyMoveToNextIteration()) {
    *result = "Disabling instance failed. Can't drop current iteration and move to the next.";
    MS_LOG(ERROR) << *result;
    return false;
  }
  *result = "Disabling FL-Server succeeded.";
  MS_LOG(INFO) << *result;

  // End disabling the server instance.
  is_instance_being_updated_ = false;
  return true;
}

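// Finishes the current instance and starts a fresh one: all rounds and the model store are reset, the iteration
// number goes back to 1, the metrics file is cleared, and the hyper-parameters from the JSON payload are applied.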
bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string *result) {
  MS_ERROR_IF_NULL_W_RET_VAL(result, false);
  // Before creating a new instance, judge whether this request should be handled.
  std::unique_lock<std::mutex> lock(instance_mtx_);
  if (is_instance_being_updated_) {
    *result = "The instance is being updated. Please retry creating a new instance later.";
    MS_LOG(WARNING) << *result;
    return false;
  }

  // Start the new server instance.
  is_instance_being_updated_ = true;

  // Reset the current instance.
  instance_state_ = InstanceState::kFinish;
  Server::GetInstance().WaitExitSafeMode();
  WaitAllRoundsFinish();
  MS_LOG(INFO) << "Proceed to a new instance.";
  for (auto &round : rounds_) {
    MS_ERROR_IF_NULL_W_RET_VAL(round, false);
    round->Reset();
  }
  iteration_num_ = 1;
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  ModelStore::GetInstance().Reset();
  if (metrics_ != nullptr) {
    if (!metrics_->Clear()) {
      MS_LOG(WARNING) << "Clearing the metrics file failed.";
    }
  }

  // Update the hyper-parameters on the server and reinitialize rounds. Release the updating flag on failure so
  // that later instance operations are not rejected forever.
  if (!UpdateHyperParams(new_instance_json)) {
    *result = "Updating hyper-parameters failed.";
    is_instance_being_updated_ = false;
    return false;
  }
  if (!ReInitRounds()) {
    *result = "Reinitializing rounds failed.";
    is_instance_being_updated_ = false;
    return false;
  }

  instance_state_ = InstanceState::kRunning;
  *result = "New FL-Server instance succeeded.";

  // End the new server instance.
  is_instance_being_updated_ = false;
  return true;
}

void Iteration::WaitAllRoundsFinish() const {
  while (running_round_num_.load() != 0) {
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
  }
}

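// A follower server asks the leader for the current iteration number so that a server that (re)joins the
// cluster starts from the same iteration as everyone else.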
bool Iteration::SyncIteration(uint32_t rank) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  SyncIterationRequest sync_iter_req;
  sync_iter_req.set_rank(rank);

  std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
  if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
                                    &sync_iter_rsp_msg)) {
    MS_LOG(ERROR) << "Sending the synchronize-iteration message to the leader server failed.";
    return false;
  }

  MS_ERROR_IF_NULL_W_RET_VAL(sync_iter_rsp_msg, false);
  SyncIterationResponse sync_iter_rsp;
  (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
  iteration_num_ = sync_iter_rsp.iteration();
  MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
               << sync_iter_rsp.iteration();
  return true;
}

void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

  SyncIterationRequest sync_iter_req;
  (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  uint32_t rank = sync_iter_req.rank();
  MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank;

  SyncIterationResponse sync_iter_rsp;
  sync_iter_rsp.set_iteration(iteration_num_);
  std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

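// Pins the given iteration number so that concurrent or repeated move-to-next requests for the same iteration
// are handled only once. The pin is released in EndLastIter.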
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
  std::unique_lock<std::mutex> lock(pinned_mtx_);
  if (pinned_iter_num_ == iteration_num) {
    MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
    return true;
  }
  pinned_iter_num_ = iteration_num;
  return false;
}

bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify the leader server to drive the cluster to the next iteration.";
  NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
  notify_leader_to_next_iter_req.set_rank(server_node_->rank_id());
  notify_leader_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
  notify_leader_to_next_iter_req.set_reason(reason);
  if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
                                    ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
    MS_LOG(WARNING) << "Sending the move-to-next-iteration notification to leader server 0 failed.";
    return false;
  }
  return true;
}

void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
  notify_leader_to_next_iter_rsp.set_result("success");
  std::string rsp_msg = notify_leader_to_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(rsp_msg.data(), rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }

  NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
  (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &rank = notify_leader_to_next_iter_req.rank();
  const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
  const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
  const auto &reason = notify_leader_to_next_iter_req.reason();
  MS_LOG(INFO) << "Leader server receives NotifyLeaderMoveToNextIterRequest from rank " << rank
               << ". Iteration number: " << iter_num << ". Reason: " << reason;

  if (IsMoveToNextIterRequestReentrant(iter_num)) {
    return;
  }

  if (!BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason)) {
    MS_LOG(ERROR) << "Broadcasting the prepare-for-next-iteration request failed.";
    return;
  }
  if (!BroadcastMoveToNextIterRequest(is_last_iter_valid, reason)) {
    MS_LOG(ERROR) << "Broadcasting the proceed-to-next-iteration request failed.";
    return;
  }
  if (!BroadcastEndLastIterRequest(iteration_num_)) {
    MS_LOG(ERROR) << "Broadcasting the end-last-iteration request failed.";
    return;
  }
}

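// Leader side of the prepare phase: switch this server into safe mode first, then tell every follower to do the
// same. Followers that cannot be reached are retried until the communicator stops, because the cluster should
// not proceed while some server is still accepting requests for the old iteration.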
bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  PrepareForNextIter();
  MS_LOG(INFO) << "Notify all follower servers to prepare for the next iteration.";
  PrepareForNextIterRequest prepare_next_iter_req;
  prepare_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  prepare_next_iter_req.set_reason(reason);

  std::vector<uint32_t> offline_servers = {};
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
      MS_LOG(WARNING) << "Sending the prepare-for-next-iteration request to server " << i << " failed. Retry later.";
      offline_servers.push_back(i);
      continue;
    }
  }

  // Retry sending to offline servers to notify them to prepare.
  (void)std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
    // Avoid an endless loop if the server communicator is stopped.
    while (communicator_->running() &&
           !communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
      MS_LOG(WARNING) << "Retry sending the prepare-for-next-iteration request to server " << rank
                      << " failed. The server has not recovered yet.";
      std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
    }
    MS_LOG(INFO) << "Offline server " << rank << " successfully prepared for the next iteration.";
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
  return true;
}

void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  PrepareForNextIterRequest prepare_next_iter_req;
  (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &reason = prepare_next_iter_req.reason();
  MS_LOG(INFO) << "Prepare the next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
  PrepareForNextIter();

  PrepareForNextIterResponse prepare_next_iter_rsp;
  prepare_next_iter_rsp.set_result("success");
  std::string rsp_msg = prepare_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(rsp_msg.data(), rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

void Iteration::PrepareForNextIter() {
  MS_LOG(INFO) << "Prepare for the next iteration. Switch the server to safe mode.";
  Server::GetInstance().SwitchToSafeMode();
  WaitAllRoundsFinish();
}

bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify all follower servers to proceed to the next iteration. Set last iteration number "
               << iteration_num_;
  MoveToNextIterRequest proceed_to_next_iter_req;
  proceed_to_next_iter_req.set_is_last_iter_valid(is_last_iter_valid);
  proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
  proceed_to_next_iter_req.set_reason(reason);
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
      MS_LOG(WARNING) << "Sending the proceed-to-next-iteration request to server " << i << " failed.";
      continue;
    }
  }

  Next(is_last_iter_valid, reason);
  return true;
}

void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

  MoveToNextIterRequest proceed_to_next_iter_req;
  (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
  const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
  const auto &reason = proceed_to_next_iter_req.reason();

  MS_LOG(INFO) << "Receive the proceed-to-next-iteration request. This server's current iteration is "
               << iteration_num_ << ". The iteration number from the leader server is " << last_iter_num
               << ". Last iteration is valid: " << is_last_iter_valid << ". Reason: " << reason;
  // Synchronize the iteration number with the leader server.
  iteration_num_ = last_iter_num;
  Next(is_last_iter_valid, reason);

  MoveToNextIterResponse proceed_to_next_iter_rsp;
  proceed_to_next_iter_rsp.set_result("success");
  std::string rsp_msg = proceed_to_next_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(rsp_msg.data(), rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

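// Finishes the bookkeeping for the ending iteration: a valid iteration stores the freshly aggregated model,
// while an invalid one re-stores the latest model so the model history stays contiguous. All rounds are reset
// for the next iteration either way.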
void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
  MS_LOG(INFO) << "Prepare for the next iteration.";
  is_last_iteration_valid_ = is_iteration_valid;
  if (is_iteration_valid) {
    // Store the model which is successfully aggregated for this iteration.
    const auto &model = Executor::GetInstance().GetModel();
    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
    MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
  } else {
    // Store the last iteration's model because this iteration is considered invalid.
    const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
    size_t latest_iter_num = iter_to_model.rbegin()->first;
    const auto &model = ModelStore::GetInstance().GetModelByIterNum(latest_iter_num);
    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
    MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
  }

  for (auto &round : rounds_) {
    MS_ERROR_IF_NULL_WO_RET_VAL(round);
    round->Reset();
  }
}

bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
  MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
  MS_LOG(INFO) << "Notify all follower servers to end the last iteration.";
  EndLastIterRequest end_last_iter_req;
  end_last_iter_req.set_last_iter_num(last_iter_num);
  for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
    if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
      MS_LOG(WARNING) << "Sending the end-last-iteration request to server " << i << " failed.";
      continue;
    }
  }

  EndLastIter();
  return true;
}

void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
  EndLastIterRequest end_last_iter_req;
  (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
  const auto &last_iter_num = end_last_iter_req.last_iter_num();
  // If the iteration number does not match, return an error.
  if (last_iter_num != iteration_num_) {
    std::string reason = "The iteration of this server " + std::to_string(server_node_->rank_id()) + " is " +
                         std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
    EndLastIterResponse end_last_iter_rsp;
    end_last_iter_rsp.set_result(reason);
    std::string rsp_msg = end_last_iter_rsp.SerializeAsString();
    if (!communicator_->SendResponse(rsp_msg.data(), rsp_msg.size(), message)) {
      MS_LOG(ERROR) << "Sending response failed.";
    }
    return;
  }

  EndLastIter();

  EndLastIterResponse end_last_iter_rsp;
  end_last_iter_rsp.set_result("success");
  std::string rsp_msg = end_last_iter_rsp.SerializeAsString();
  if (!communicator_->SendResponse(rsp_msg.data(), rsp_msg.size(), message)) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }
}

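// Final phase of the iteration transition: mark the instance finished when the configured number of iterations
// is reached, release the reentrancy pin, summarize metrics on the leader, bump the iteration number and leave
// safe mode so the rounds of the next iteration can start.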
void Iteration::EndLastIter() {
  MS_LOG(INFO) << "End the last iteration " << iteration_num_;
  if (iteration_num_ == ps::PSContext::instance()->fl_iteration_num()) {
    MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
                 << " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
    iteration_loop_count_++;
    instance_state_ = InstanceState::kFinish;
  }

  std::unique_lock<std::mutex> lock(pinned_mtx_);
  pinned_iter_num_ = 0;
  lock.unlock();

  SetIterationCompleted();
  if (!SummarizeIteration()) {
    MS_LOG(WARNING) << "Summarizing iteration data failed.";
  }
  iteration_num_++;
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  Server::GetInstance().CancelSafeMode();
  iteration_state_cv_.notify_all();
  MS_LOG(INFO) << "Move to next iteration: " << iteration_num_ << "\n";
}

bool Iteration::ForciblyMoveToNextIteration() {
  NotifyNext(false, "Forcibly move to next iteration.");
  return true;
}

bool Iteration::SummarizeIteration() {
  // If metrics_ is not initialized or this server is not the leader server, do not summarize.
  if (server_node_->rank_id() != kLeaderServerRank || metrics_ == nullptr) {
    MS_LOG(INFO) << "This server will not summarize the iteration.";
    return true;
  }

  metrics_->set_fl_name(ps::PSContext::instance()->fl_name());
  metrics_->set_fl_iteration_num(ps::PSContext::instance()->fl_iteration_num());
  metrics_->set_cur_iteration_num(iteration_num_);
  metrics_->set_instance_state(instance_state_.load());
  metrics_->set_loss(loss_);
  metrics_->set_accuracy(accuracy_);
  // The joined client number is equal to the threshold of updateModel.
  size_t update_model_threshold = static_cast<size_t>(
    std::ceil(ps::PSContext::instance()->start_fl_job_threshold() * ps::PSContext::instance()->update_model_ratio()));
  metrics_->set_joined_client_num(update_model_threshold);
  // The rejected client number is the threshold of startFLJob minus the threshold of updateModel.
  metrics_->set_rejected_client_num(ps::PSContext::instance()->start_fl_job_threshold() - update_model_threshold);

  if (complete_timestamp_ < start_timestamp_) {
    MS_LOG(ERROR) << "The complete_timestamp_: " << complete_timestamp_ << ", start_timestamp_: " << start_timestamp_
                  << ". One of them is invalid.";
    metrics_->set_iteration_time_cost(UINT64_MAX);
  } else {
    metrics_->set_iteration_time_cost(complete_timestamp_ - start_timestamp_);
  }

  if (!metrics_->Summarize()) {
    MS_LOG(ERROR) << "Summarizing metrics failed.";
    return false;
  }
  return true;
}

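// Applies the hyper-parameters carried in the new-instance JSON payload. Unknown keys are ignored. A payload
// could look like the following (values are illustrative only):
//   {
//     "start_fl_job_threshold": 100,
//     "start_fl_job_time_window": 30000,
//     "update_model_ratio": 0.8,
//     "update_model_time_window": 30000,
//     "fl_iteration_num": 25,
//     "client_epoch_num": 1,
//     "client_batch_size": 32,
//     "client_learning_rate": 0.01
//   }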
bool Iteration::UpdateHyperParams(const nlohmann::json &json) {
  for (const auto &item : json.items()) {
    std::string key = item.key();
    if (key == "start_fl_job_threshold") {
      ps::PSContext::instance()->set_start_fl_job_threshold(item.value().get<uint64_t>());
      continue;
    }
    if (key == "start_fl_job_time_window") {
      ps::PSContext::instance()->set_start_fl_job_time_window(item.value().get<uint64_t>());
      continue;
    }
    if (key == "update_model_ratio") {
      ps::PSContext::instance()->set_update_model_ratio(item.value().get<float>());
      continue;
    }
    if (key == "update_model_time_window") {
      ps::PSContext::instance()->set_update_model_time_window(item.value().get<uint64_t>());
      continue;
    }
    if (key == "fl_iteration_num") {
      ps::PSContext::instance()->set_fl_iteration_num(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_epoch_num") {
      ps::PSContext::instance()->set_client_epoch_num(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_batch_size") {
      ps::PSContext::instance()->set_client_batch_size(item.value().get<uint64_t>());
      continue;
    }
    if (key == "client_learning_rate") {
      ps::PSContext::instance()->set_client_learning_rate(item.value().get<float>());
      continue;
    }
  }
  return true;
}

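// Rebuilds the startFLJob/updateModel round configs and the executor threshold from the freshly updated context.
// For example (illustrative numbers): with start_fl_job_threshold = 100 and update_model_ratio = 0.8, the
// updateModel threshold becomes ceil(100 * 0.8) = 80, which is also the executor threshold in FL/hybrid mode.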
bool Iteration::ReInitRounds() {
  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
  float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
  uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
  uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
  std::vector<RoundConfig> new_round_config = {
    {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
    {"updateModel", true, update_model_time_window, true, update_model_threshold}};
  if (!ReInitForUpdatingHyperParams(new_round_config)) {
    MS_LOG(ERROR) << "Reinitializing for updating hyper-parameters failed.";
    return false;
  }

  size_t executor_threshold = 0;
  const std::string &server_mode = ps::PSContext::instance()->server_mode();
  uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
  if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
    executor_threshold = update_model_threshold;
  } else if (server_mode == ps::kServerModePS) {
    executor_threshold = worker_num;
  } else {
    MS_LOG(ERROR) << "Server mode " << server_mode << " is not supported.";
    return false;
  }
  if (!Executor::GetInstance().ReInitForUpdatingHyperParams(executor_threshold)) {
    MS_LOG(ERROR) << "Reinitializing executor failed.";
    return false;
  }
  return true;
}
}  // namespace server
}  // namespace fl
}  // namespace mindspore