1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fl/server/server.h"
18 #include <memory>
19 #include <string>
20 #include <csignal>
21 #ifdef ENABLE_ARMOUR
22 #include "fl/armour/secure_protocol/secret_sharing.h"
23 #endif
24 #include "fl/server/round.h"
25 #include "fl/server/model_store.h"
26 #include "fl/server/iteration.h"
27 #include "fl/server/collective_ops_impl.h"
28 #include "fl/server/distributed_metadata_store.h"
29 #include "fl/server/distributed_count_service.h"
30 #include "fl/server/kernel/round/round_kernel_factory.h"
31
32 namespace mindspore {
33 namespace fl {
34 namespace server {
35 // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
36 std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
37 std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
SignalHandler(int signal)38 void SignalHandler(int signal) {
39 MS_LOG(WARNING) << "SIGTERM captured: " << signal;
40 (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
41 [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
42 MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
43 (void)communicator->Stop();
44 });
45
46 MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
47 (void)g_communicator_with_server->Stop();
48 return;
49 }
50
Initialize(bool use_tcp,bool use_http,uint16_t http_port,const std::vector<RoundConfig> & rounds_config,const CipherConfig & cipher_config,const FuncGraphPtr & func_graph,size_t executor_threshold)51 void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
52 const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
53 MS_EXCEPTION_IF_NULL(func_graph);
54 func_graph_ = func_graph;
55
56 if (rounds_config.empty()) {
57 MS_LOG(EXCEPTION) << "Rounds are empty.";
58 return;
59 }
60 rounds_config_ = rounds_config;
61 cipher_config_ = cipher_config;
62
63 use_tcp_ = use_tcp;
64 use_http_ = use_http;
65 http_port_ = http_port;
66 executor_threshold_ = executor_threshold;
67 (void)signal(SIGTERM, SignalHandler);
68 return;
69 }
70
71 // Each step of the server pipeline may have dependency on other steps, which includes:
72
73 // InitServerContext must be the first step to set contexts for later steps.
74
75 // Server Running relies on URL or Message Type Register:
76 // StartCommunicator---->InitIteration
77
78 // Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
79 // RegisterRoundKernel---->StartCommunicator
80
81 // Kernel Initialization relies on Executor Initialization:
82 // RegisterRoundKernel---->InitExecutor
83
84 // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
85 // InitCipher---->InitExecutor
Run()86 void Server::Run() {
87 std::unique_lock<std::mutex> lock(scaling_mtx_);
88 InitServerContext();
89 InitCluster();
90 InitIteration();
91 RegisterCommCallbacks();
92 StartCommunicator();
93 InitExecutor();
94 std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
95 if (encrypt_type != ps::kNotEncryptType) {
96 InitCipher();
97 MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
98 }
99 RegisterRoundKernel();
100 InitMetrics();
101 MS_LOG(INFO) << "Server started successfully.";
102 safemode_ = false;
103 lock.unlock();
104
105 // Wait communicators to stop so the main thread is blocked.
106 (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
107 [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
108 MS_EXCEPTION_IF_NULL(communicator);
109 communicator->Join();
110 });
111 MS_EXCEPTION_IF_NULL(communicator_with_server_);
112 communicator_with_server_->Join();
113 MsException::Instance().CheckException();
114 return;
115 }
116
SwitchToSafeMode()117 void Server::SwitchToSafeMode() {
118 MS_LOG(INFO) << "Server switch to safemode.";
119 safemode_ = true;
120 }
121
CancelSafeMode()122 void Server::CancelSafeMode() {
123 MS_LOG(INFO) << "Server cancel safemode.";
124 safemode_ = false;
125 }
126
IsSafeMode() const127 bool Server::IsSafeMode() const { return safemode_.load(); }
128
WaitExitSafeMode() const129 void Server::WaitExitSafeMode() const {
130 while (safemode_.load()) {
131 std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
132 }
133 }
134
InitServerContext()135 void Server::InitServerContext() {
136 ps::PSContext::instance()->GenerateResetterRound();
137 scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
138 scheduler_port_ = ps::PSContext::instance()->scheduler_port();
139 worker_num_ = ps::PSContext::instance()->initial_worker_num();
140 server_num_ = ps::PSContext::instance()->initial_server_num();
141 std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
142 if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
143 MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
144 return;
145 }
146 return;
147 }
148
InitCluster()149 void Server::InitCluster() {
150 server_node_ = std::make_shared<ps::core::ServerNode>();
151 MS_EXCEPTION_IF_NULL(server_node_);
152 task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
153 MS_EXCEPTION_IF_NULL(task_executor_);
154 if (!InitCommunicatorWithServer()) {
155 MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
156 return;
157 }
158 if (!InitCommunicatorWithWorker()) {
159 MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
160 return;
161 }
162 return;
163 }
164
InitCommunicatorWithServer()165 bool Server::InitCommunicatorWithServer() {
166 MS_EXCEPTION_IF_NULL(task_executor_);
167 MS_EXCEPTION_IF_NULL(server_node_);
168 communicator_with_server_ = server_node_->GetOrCreateTcpComm(scheduler_ip_, static_cast<int16_t>(scheduler_port_),
169 worker_num_, server_num_, task_executor_);
170 MS_EXCEPTION_IF_NULL(communicator_with_server_);
171 g_communicator_with_server = communicator_with_server_;
172 return true;
173 }
174
InitCommunicatorWithWorker()175 bool Server::InitCommunicatorWithWorker() {
176 MS_EXCEPTION_IF_NULL(server_node_);
177 MS_EXCEPTION_IF_NULL(task_executor_);
178 if (!use_tcp_ && !use_http_) {
179 MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
180 return false;
181 }
182 if (use_tcp_) {
183 MS_EXCEPTION_IF_NULL(communicator_with_server_);
184 auto tcp_comm = communicator_with_server_;
185 MS_EXCEPTION_IF_NULL(tcp_comm);
186 communicators_with_worker_.push_back(tcp_comm);
187 }
188 if (use_http_) {
189 auto http_comm = server_node_->GetOrCreateHttpComm(server_node_->BoundIp(), http_port_, task_executor_);
190 MS_EXCEPTION_IF_NULL(http_comm);
191 communicators_with_worker_.push_back(http_comm);
192 }
193 g_communicators_with_worker = communicators_with_worker_;
194 return true;
195 }
196
InitIteration()197 void Server::InitIteration() {
198 iteration_ = &Iteration::GetInstance();
199 MS_EXCEPTION_IF_NULL(iteration_);
200
201 // 1.Add rounds to the iteration according to the server mode.
202 for (const RoundConfig &config : rounds_config_) {
203 std::shared_ptr<Round> round =
204 std::make_shared<Round>(config.name, config.check_timeout, config.time_window, config.check_count,
205 config.threshold_count, config.server_num_as_threshold);
206 MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
207 << ", time window: " << config.time_window << ", check_count: " << config.check_count
208 << ", threshold: " << config.threshold_count
209 << ", server_num_as_threshold: " << config.server_num_as_threshold;
210 iteration_->AddRound(round);
211 }
212
213 #ifdef ENABLE_ARMOUR
214 std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
215 if (encrypt_type == ps::kPWEncryptType) {
216 cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
217 cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
218 cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
219 cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
220 cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
221 cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
222 cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
223 cipher_time_window_ = cipher_config_.cipher_time_window;
224
225 MS_LOG(INFO) << "Initializing cipher:";
226 MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
227 << " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
228 << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
229 MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
230 << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
231 << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
232 << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
233 << " cipher_time_window_: " << cipher_time_window_;
234 }
235 #endif
236
237 // 2.Initialize all the rounds.
238 TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
239 FinishIterCb finish_iter_cb =
240 std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
241 iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
242 return;
243 }
244
InitCipher()245 void Server::InitCipher() {
246 #ifdef ENABLE_ARMOUR
247 cipher_init_ = &armour::CipherInit::GetInstance();
248 int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
249 unsigned char cipher_p[SECRET_MAX_LEN] = {0};
250 const int cipher_g = 1;
251 float dp_eps = ps::PSContext::instance()->dp_eps();
252 float dp_delta = ps::PSContext::instance()->dp_delta();
253 float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
254 std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
255
256 mindspore::armour::CipherPublicPara param;
257 param.g = cipher_g;
258 param.t = cipher_t;
259 int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
260 if (ret != 0) {
261 MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
262 return;
263 }
264 param.dp_delta = dp_delta;
265 param.dp_eps = dp_eps;
266 param.dp_norm_clip = dp_norm_clip;
267 param.encrypt_type = encrypt_type;
268
269 BIGNUM *prim = BN_new();
270 if (prim == NULL) {
271 MS_LOG(EXCEPTION) << "new bn failed.";
272 ret = -1;
273 } else {
274 ret = mindspore::armour::GetPrime(prim);
275 }
276 if (ret == 0) {
277 (void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(param.prime));
278 } else {
279 MS_LOG(EXCEPTION) << "Get prime failed.";
280 }
281 if (prim != NULL) {
282 BN_clear_free(prim);
283 }
284
285 (void)cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
286 cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_up_cnt_);
287 #endif
288 }
289
RegisterCommCallbacks()290 void Server::RegisterCommCallbacks() {
291 // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
292 // rounds' callbacks.
293 MS_EXCEPTION_IF_NULL(server_node_);
294 MS_EXCEPTION_IF_NULL(iteration_);
295
296 auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
297 MS_EXCEPTION_IF_NULL(tcp_comm);
298
299 // Set message callbacks for server-to-server communication.
300 DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
301 DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);
302 iteration_->RegisterMessageCallback(tcp_comm);
303 iteration_->RegisterEventCallback(server_node_);
304
305 // Set exception event callbacks for server.
306 RegisterExceptionEventCallback(tcp_comm);
307 // Set message callbacks for server.
308 RegisterMessageCallback(tcp_comm);
309
310 if (!server_node_->InitFollowerScaler()) {
311 MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
312 return;
313 }
314 // Set scaling barriers before scaling.
315 server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline",
316 std::bind(&Server::ProcessBeforeScalingOut, this));
317 server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline",
318 std::bind(&Server::ProcessBeforeScalingIn, this));
319 // Set handlers after scheduler scaling operations are done.
320 server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline",
321 std::bind(&Server::ProcessAfterScalingOut, this));
322 server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline",
323 std::bind(&Server::ProcessAfterScalingIn, this));
324 }
325
RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)326 void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
327 MS_EXCEPTION_IF_NULL(communicator);
328 communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
329 MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
330 safemode_ = true;
331 (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
332 [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
333 MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
334 (void)communicator->Stop();
335 });
336
337 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
338 (void)communicator_with_server_->Stop();
339 });
340
341 communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
342 MS_LOG(ERROR)
343 << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
344 "network building phase.";
345 safemode_ = true;
346 (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
347 [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
348 MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
349 (void)communicator->Stop();
350 });
351
352 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
353 (void)communicator_with_server_->Stop();
354 });
355 }
356
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)357 void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
358 MS_EXCEPTION_IF_NULL(communicator);
359 // Register handler for restful requests receviced by scheduler.
360 communicator->RegisterMsgCallBack("enableFLS",
361 std::bind(&Server::HandleEnableServerRequest, this, std::placeholders::_1));
362 communicator->RegisterMsgCallBack("disableFLS",
363 std::bind(&Server::HandleDisableServerRequest, this, std::placeholders::_1));
364 communicator->RegisterMsgCallBack("newInstance",
365 std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
366 communicator->RegisterMsgCallBack("queryInstance",
367 std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
368 }
369
InitExecutor()370 void Server::InitExecutor() {
371 MS_EXCEPTION_IF_NULL(func_graph_);
372 if (executor_threshold_ == 0) {
373 MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
374 return;
375 }
376 // The train engine instance is used in both push-type and pull-type kernels,
377 // so the required_cnt of these kernels must be the same as executor_threshold_.
378 MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
379 Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
380 ModelStore::GetInstance().Initialize();
381 return;
382 }
383
RegisterRoundKernel()384 void Server::RegisterRoundKernel() {
385 MS_EXCEPTION_IF_NULL(iteration_);
386 auto &rounds = iteration_->rounds();
387 if (rounds.empty()) {
388 MS_LOG(EXCEPTION) << "Server has no round registered.";
389 return;
390 }
391
392 for (auto &round : rounds) {
393 MS_EXCEPTION_IF_NULL(round);
394 const std::string &name = round->name();
395 std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
396 if (round_kernel == nullptr) {
397 MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
398 return;
399 }
400
401 // For some round kernels, the threshold count should be set.
402 round_kernel->InitKernel(round->threshold_count());
403 round->BindRoundKernel(round_kernel);
404 }
405 return;
406 }
407
InitMetrics()408 void Server::InitMetrics() {
409 if (server_node_->rank_id() == kLeaderServerRank) {
410 MS_EXCEPTION_IF_NULL(iteration_);
411 std::shared_ptr<IterationMetrics> iteration_metrics =
412 std::make_shared<IterationMetrics>(ps::PSContext::instance()->config_file_path());
413 if (!iteration_metrics->Initialize()) {
414 MS_LOG(WARNING) << "Initializing metrics failed.";
415 return;
416 }
417 iteration_->set_metrics(iteration_metrics);
418 }
419 }
420
StartCommunicator()421 void Server::StartCommunicator() {
422 if (communicators_with_worker_.empty()) {
423 MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
424 return;
425 }
426
427 MS_EXCEPTION_IF_NULL(server_node_);
428 MS_EXCEPTION_IF_NULL(communicator_with_server_);
429 MS_LOG(INFO) << "Start communicator with server.";
430 if (!communicator_with_server_->Start()) {
431 MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
432 return;
433 }
434 DistributedMetadataStore::GetInstance().Initialize(server_node_);
435 CollectiveOpsImpl::GetInstance().Initialize(server_node_);
436 DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
437 MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
438
439 MS_LOG(INFO) << "Start communicator with worker.";
440 (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
441 [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
442 MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
443 if (!communicator->Start()) {
444 MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
445 }
446 });
447 }
448
ProcessBeforeScalingOut()449 void Server::ProcessBeforeScalingOut() {
450 MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
451 iteration_->ScalingBarrier();
452 safemode_ = true;
453 }
454
ProcessBeforeScalingIn()455 void Server::ProcessBeforeScalingIn() {
456 MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
457 iteration_->ScalingBarrier();
458 safemode_ = true;
459 }
460
ProcessAfterScalingOut()461 void Server::ProcessAfterScalingOut() {
462 std::unique_lock<std::mutex> lock(scaling_mtx_);
463 MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
464 if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
465 MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
466 }
467 if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
468 MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
469 }
470 if (!DistributedCountService::GetInstance().ReInitForScaling()) {
471 MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
472 }
473 if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
474 MS_LOG(WARNING) << "Iteration reinitializing failed.";
475 }
476 if (!Executor::GetInstance().ReInitForScaling()) {
477 MS_LOG(WARNING) << "Executor reinitializing failed.";
478 }
479 std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
480 safemode_ = false;
481 }
482
ProcessAfterScalingIn()483 void Server::ProcessAfterScalingIn() {
484 std::unique_lock<std::mutex> lock(scaling_mtx_);
485 MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
486 if (server_node_->rank_id() == UINT32_MAX) {
487 MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
488 return;
489 }
490
491 // If the server is not the one to be scaled in, reintialize modules and recover service.
492 if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
493 MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
494 }
495 if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
496 MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
497 }
498 if (!DistributedCountService::GetInstance().ReInitForScaling()) {
499 MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
500 }
501 if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
502 MS_LOG(WARNING) << "Iteration reinitializing failed.";
503 }
504 if (!Executor::GetInstance().ReInitForScaling()) {
505 MS_LOG(WARNING) << "Executor reinitializing failed.";
506 }
507 std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
508 safemode_ = false;
509 }
510
HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> & message)511 void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
512 MS_ERROR_IF_NULL_WO_RET_VAL(message);
513 MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
514 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
515 auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
516 MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
517
518 std::string result_message = "";
519 bool result = iteration_->EnableServerInstance(&result_message);
520 nlohmann::json response;
521 response["result"] = result;
522 response["message"] = result_message;
523 if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
524 MS_LOG(ERROR) << "Sending response failed.";
525 return;
526 }
527 }
528
HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> & message)529 void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
530 MS_ERROR_IF_NULL_WO_RET_VAL(message);
531 MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
532 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
533 auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
534 MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
535
536 std::string result_message = "";
537 bool result = iteration_->DisableServerInstance(&result_message);
538 nlohmann::json response;
539 response["result"] = result;
540 response["message"] = result_message;
541 if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
542 MS_LOG(ERROR) << "Sending response failed.";
543 return;
544 }
545 }
546
HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> & message)547 void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
548 MS_ERROR_IF_NULL_WO_RET_VAL(message);
549 MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
550 MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
551 auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
552 MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
553
554 MS_ERROR_IF_NULL_WO_RET_VAL(message->data());
555 std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
556 nlohmann::json new_instance_json;
557 nlohmann::json response;
558 try {
559 new_instance_json = nlohmann::json::parse(hyper_params_str);
560 } catch (const std::exception &e) {
561 response["result"] = false;
562 response["message"] = "The hyper-parameter data is not in json format.";
563 if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
564 MS_LOG(ERROR) << "Sending response failed.";
565 return;
566 }
567 }
568
569 std::string result_message = "";
570 bool result = iteration_->NewInstance(new_instance_json, &result_message);
571 response["result"] = result;
572 response["message"] = result_message;
573 if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
574 MS_LOG(ERROR) << "Sending response failed.";
575 return;
576 }
577 }
578
HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> & message)579 void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
580 MS_ERROR_IF_NULL_WO_RET_VAL(message);
581 nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
582 response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
583 response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
584 response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();
585 response["update_model_time_window"] = ps::PSContext::instance()->update_model_time_window();
586 response["fl_iteration_num"] = ps::PSContext::instance()->fl_iteration_num();
587 response["client_epoch_num"] = ps::PSContext::instance()->client_epoch_num();
588 response["client_batch_size"] = ps::PSContext::instance()->client_batch_size();
589 response["client_learning_rate"] = ps::PSContext::instance()->client_learning_rate();
590 auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
591 MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
592 if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
593 MS_LOG(ERROR) << "Sending response failed.";
594 return;
595 }
596 }
597 } // namespace server
598 } // namespace fl
599 } // namespace mindspore
600