• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fl/server/server.h"
18 #include <memory>
19 #include <string>
20 #include <csignal>
21 #ifdef ENABLE_ARMOUR
22 #include "fl/armour/secure_protocol/secret_sharing.h"
23 #endif
24 #include "fl/server/round.h"
25 #include "fl/server/model_store.h"
26 #include "fl/server/iteration.h"
27 #include "fl/server/collective_ops_impl.h"
28 #include "fl/server/distributed_metadata_store.h"
29 #include "fl/server/distributed_count_service.h"
30 #include "fl/server/kernel/round/round_kernel_factory.h"
31 
32 namespace mindspore {
33 namespace fl {
34 namespace server {
35 // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
36 std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
37 std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
SignalHandler(int signal)38 void SignalHandler(int signal) {
39   MS_LOG(WARNING) << "SIGTERM captured: " << signal;
40   (void)std::for_each(g_communicators_with_worker.begin(), g_communicators_with_worker.end(),
41                       [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
42                         MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
43                         (void)communicator->Stop();
44                       });
45 
46   MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
47   (void)g_communicator_with_server->Stop();
48   return;
49 }
50 
Initialize(bool use_tcp,bool use_http,uint16_t http_port,const std::vector<RoundConfig> & rounds_config,const CipherConfig & cipher_config,const FuncGraphPtr & func_graph,size_t executor_threshold)51 void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
52                         const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
53   MS_EXCEPTION_IF_NULL(func_graph);
54   func_graph_ = func_graph;
55 
56   if (rounds_config.empty()) {
57     MS_LOG(EXCEPTION) << "Rounds are empty.";
58     return;
59   }
60   rounds_config_ = rounds_config;
61   cipher_config_ = cipher_config;
62 
63   use_tcp_ = use_tcp;
64   use_http_ = use_http;
65   http_port_ = http_port;
66   executor_threshold_ = executor_threshold;
67   (void)signal(SIGTERM, SignalHandler);
68   return;
69 }
70 
71 // Each step of the server pipeline may have dependency on other steps, which includes:
72 
73 // InitServerContext must be the first step to set contexts for later steps.
74 
75 // Server Running relies on URL or Message Type Register:
76 // StartCommunicator---->InitIteration
77 
78 // Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
79 // RegisterRoundKernel---->StartCommunicator
80 
81 // Kernel Initialization relies on Executor Initialization:
82 // RegisterRoundKernel---->InitExecutor
83 
84 // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
85 // InitCipher---->InitExecutor
Run()86 void Server::Run() {
87   std::unique_lock<std::mutex> lock(scaling_mtx_);
88   InitServerContext();
89   InitCluster();
90   InitIteration();
91   RegisterCommCallbacks();
92   StartCommunicator();
93   InitExecutor();
94   std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
95   if (encrypt_type != ps::kNotEncryptType) {
96     InitCipher();
97     MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
98   }
99   RegisterRoundKernel();
100   InitMetrics();
101   MS_LOG(INFO) << "Server started successfully.";
102   safemode_ = false;
103   lock.unlock();
104 
105   // Wait communicators to stop so the main thread is blocked.
106   (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
107                       [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
108                         MS_EXCEPTION_IF_NULL(communicator);
109                         communicator->Join();
110                       });
111   MS_EXCEPTION_IF_NULL(communicator_with_server_);
112   communicator_with_server_->Join();
113   MsException::Instance().CheckException();
114   return;
115 }
116 
SwitchToSafeMode()117 void Server::SwitchToSafeMode() {
118   MS_LOG(INFO) << "Server switch to safemode.";
119   safemode_ = true;
120 }
121 
CancelSafeMode()122 void Server::CancelSafeMode() {
123   MS_LOG(INFO) << "Server cancel safemode.";
124   safemode_ = false;
125 }
126 
IsSafeMode() const127 bool Server::IsSafeMode() const { return safemode_.load(); }
128 
WaitExitSafeMode() const129 void Server::WaitExitSafeMode() const {
130   while (safemode_.load()) {
131     std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
132   }
133 }
134 
InitServerContext()135 void Server::InitServerContext() {
136   ps::PSContext::instance()->GenerateResetterRound();
137   scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
138   scheduler_port_ = ps::PSContext::instance()->scheduler_port();
139   worker_num_ = ps::PSContext::instance()->initial_worker_num();
140   server_num_ = ps::PSContext::instance()->initial_server_num();
141   std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
142   if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
143     MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
144     return;
145   }
146   return;
147 }
148 
InitCluster()149 void Server::InitCluster() {
150   server_node_ = std::make_shared<ps::core::ServerNode>();
151   MS_EXCEPTION_IF_NULL(server_node_);
152   task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
153   MS_EXCEPTION_IF_NULL(task_executor_);
154   if (!InitCommunicatorWithServer()) {
155     MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
156     return;
157   }
158   if (!InitCommunicatorWithWorker()) {
159     MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
160     return;
161   }
162   return;
163 }
164 
InitCommunicatorWithServer()165 bool Server::InitCommunicatorWithServer() {
166   MS_EXCEPTION_IF_NULL(task_executor_);
167   MS_EXCEPTION_IF_NULL(server_node_);
168   communicator_with_server_ = server_node_->GetOrCreateTcpComm(scheduler_ip_, static_cast<int16_t>(scheduler_port_),
169                                                                worker_num_, server_num_, task_executor_);
170   MS_EXCEPTION_IF_NULL(communicator_with_server_);
171   g_communicator_with_server = communicator_with_server_;
172   return true;
173 }
174 
InitCommunicatorWithWorker()175 bool Server::InitCommunicatorWithWorker() {
176   MS_EXCEPTION_IF_NULL(server_node_);
177   MS_EXCEPTION_IF_NULL(task_executor_);
178   if (!use_tcp_ && !use_http_) {
179     MS_LOG(EXCEPTION) << "At least one type of protocol should be set.";
180     return false;
181   }
182   if (use_tcp_) {
183     MS_EXCEPTION_IF_NULL(communicator_with_server_);
184     auto tcp_comm = communicator_with_server_;
185     MS_EXCEPTION_IF_NULL(tcp_comm);
186     communicators_with_worker_.push_back(tcp_comm);
187   }
188   if (use_http_) {
189     auto http_comm = server_node_->GetOrCreateHttpComm(server_node_->BoundIp(), http_port_, task_executor_);
190     MS_EXCEPTION_IF_NULL(http_comm);
191     communicators_with_worker_.push_back(http_comm);
192   }
193   g_communicators_with_worker = communicators_with_worker_;
194   return true;
195 }
196 
InitIteration()197 void Server::InitIteration() {
198   iteration_ = &Iteration::GetInstance();
199   MS_EXCEPTION_IF_NULL(iteration_);
200 
201   // 1.Add rounds to the iteration according to the server mode.
202   for (const RoundConfig &config : rounds_config_) {
203     std::shared_ptr<Round> round =
204       std::make_shared<Round>(config.name, config.check_timeout, config.time_window, config.check_count,
205                               config.threshold_count, config.server_num_as_threshold);
206     MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
207                  << ", time window: " << config.time_window << ", check_count: " << config.check_count
208                  << ", threshold: " << config.threshold_count
209                  << ", server_num_as_threshold: " << config.server_num_as_threshold;
210     iteration_->AddRound(round);
211   }
212 
213 #ifdef ENABLE_ARMOUR
214   std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
215   if (encrypt_type == ps::kPWEncryptType) {
216     cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
217     cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
218     cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
219     cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
220     cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
221     cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
222     cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
223     cipher_time_window_ = cipher_config_.cipher_time_window;
224 
225     MS_LOG(INFO) << "Initializing cipher:";
226     MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
227                  << " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
228                  << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
229     MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
230                  << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
231                  << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
232                  << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
233                  << " cipher_time_window_: " << cipher_time_window_;
234   }
235 #endif
236 
237   // 2.Initialize all the rounds.
238   TimeOutCb time_out_cb = std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
239   FinishIterCb finish_iter_cb =
240     std::bind(&Iteration::NotifyNext, iteration_, std::placeholders::_1, std::placeholders::_2);
241   iteration_->InitRounds(communicators_with_worker_, time_out_cb, finish_iter_cb);
242   return;
243 }
244 
InitCipher()245 void Server::InitCipher() {
246 #ifdef ENABLE_ARMOUR
247   cipher_init_ = &armour::CipherInit::GetInstance();
248   int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
249   unsigned char cipher_p[SECRET_MAX_LEN] = {0};
250   const int cipher_g = 1;
251   float dp_eps = ps::PSContext::instance()->dp_eps();
252   float dp_delta = ps::PSContext::instance()->dp_delta();
253   float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
254   std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
255 
256   mindspore::armour::CipherPublicPara param;
257   param.g = cipher_g;
258   param.t = cipher_t;
259   int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
260   if (ret != 0) {
261     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
262     return;
263   }
264   param.dp_delta = dp_delta;
265   param.dp_eps = dp_eps;
266   param.dp_norm_clip = dp_norm_clip;
267   param.encrypt_type = encrypt_type;
268 
269   BIGNUM *prim = BN_new();
270   if (prim == NULL) {
271     MS_LOG(EXCEPTION) << "new bn failed.";
272     ret = -1;
273   } else {
274     ret = mindspore::armour::GetPrime(prim);
275   }
276   if (ret == 0) {
277     (void)BN_bn2bin(prim, reinterpret_cast<uint8_t *>(param.prime));
278   } else {
279     MS_LOG(EXCEPTION) << "Get prime failed.";
280   }
281   if (prim != NULL) {
282     BN_clear_free(prim);
283   }
284 
285   (void)cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
286                            cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_up_cnt_);
287 #endif
288 }
289 
RegisterCommCallbacks()290 void Server::RegisterCommCallbacks() {
291   // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
292   // rounds' callbacks.
293   MS_EXCEPTION_IF_NULL(server_node_);
294   MS_EXCEPTION_IF_NULL(iteration_);
295 
296   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
297   MS_EXCEPTION_IF_NULL(tcp_comm);
298 
299   // Set message callbacks for server-to-server communication.
300   DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
301   DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);
302   iteration_->RegisterMessageCallback(tcp_comm);
303   iteration_->RegisterEventCallback(server_node_);
304 
305   // Set exception event callbacks for server.
306   RegisterExceptionEventCallback(tcp_comm);
307   // Set message callbacks for server.
308   RegisterMessageCallback(tcp_comm);
309 
310   if (!server_node_->InitFollowerScaler()) {
311     MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
312     return;
313   }
314   // Set scaling barriers before scaling.
315   server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline",
316                                                             std::bind(&Server::ProcessBeforeScalingOut, this));
317   server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline",
318                                                            std::bind(&Server::ProcessBeforeScalingIn, this));
319   // Set handlers after scheduler scaling operations are done.
320   server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline",
321                                                            std::bind(&Server::ProcessAfterScalingOut, this));
322   server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline",
323                                                           std::bind(&Server::ProcessAfterScalingIn, this));
324 }
325 
RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)326 void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
327   MS_EXCEPTION_IF_NULL(communicator);
328   communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
329     MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
330     safemode_ = true;
331     (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
332                         [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
333                           MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
334                           (void)communicator->Stop();
335                         });
336 
337     MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
338     (void)communicator_with_server_->Stop();
339   });
340 
341   communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
342     MS_LOG(ERROR)
343       << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
344          "network building phase.";
345     safemode_ = true;
346     (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
347                         [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
348                           MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
349                           (void)communicator->Stop();
350                         });
351 
352     MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
353     (void)communicator_with_server_->Stop();
354   });
355 }
356 
RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> & communicator)357 void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
358   MS_EXCEPTION_IF_NULL(communicator);
359   // Register handler for restful requests receviced by scheduler.
360   communicator->RegisterMsgCallBack("enableFLS",
361                                     std::bind(&Server::HandleEnableServerRequest, this, std::placeholders::_1));
362   communicator->RegisterMsgCallBack("disableFLS",
363                                     std::bind(&Server::HandleDisableServerRequest, this, std::placeholders::_1));
364   communicator->RegisterMsgCallBack("newInstance",
365                                     std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
366   communicator->RegisterMsgCallBack("queryInstance",
367                                     std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
368 }
369 
InitExecutor()370 void Server::InitExecutor() {
371   MS_EXCEPTION_IF_NULL(func_graph_);
372   if (executor_threshold_ == 0) {
373     MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
374     return;
375   }
376   // The train engine instance is used in both push-type and pull-type kernels,
377   // so the required_cnt of these kernels must be the same as executor_threshold_.
378   MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
379   Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
380   ModelStore::GetInstance().Initialize();
381   return;
382 }
383 
RegisterRoundKernel()384 void Server::RegisterRoundKernel() {
385   MS_EXCEPTION_IF_NULL(iteration_);
386   auto &rounds = iteration_->rounds();
387   if (rounds.empty()) {
388     MS_LOG(EXCEPTION) << "Server has no round registered.";
389     return;
390   }
391 
392   for (auto &round : rounds) {
393     MS_EXCEPTION_IF_NULL(round);
394     const std::string &name = round->name();
395     std::shared_ptr<kernel::RoundKernel> round_kernel = kernel::RoundKernelFactory::GetInstance().Create(name);
396     if (round_kernel == nullptr) {
397       MS_LOG(EXCEPTION) << "Round kernel for round " << name << " is not registered.";
398       return;
399     }
400 
401     // For some round kernels, the threshold count should be set.
402     round_kernel->InitKernel(round->threshold_count());
403     round->BindRoundKernel(round_kernel);
404   }
405   return;
406 }
407 
InitMetrics()408 void Server::InitMetrics() {
409   if (server_node_->rank_id() == kLeaderServerRank) {
410     MS_EXCEPTION_IF_NULL(iteration_);
411     std::shared_ptr<IterationMetrics> iteration_metrics =
412       std::make_shared<IterationMetrics>(ps::PSContext::instance()->config_file_path());
413     if (!iteration_metrics->Initialize()) {
414       MS_LOG(WARNING) << "Initializing metrics failed.";
415       return;
416     }
417     iteration_->set_metrics(iteration_metrics);
418   }
419 }
420 
StartCommunicator()421 void Server::StartCommunicator() {
422   if (communicators_with_worker_.empty()) {
423     MS_LOG(EXCEPTION) << "Communicators for communication with worker is empty.";
424     return;
425   }
426 
427   MS_EXCEPTION_IF_NULL(server_node_);
428   MS_EXCEPTION_IF_NULL(communicator_with_server_);
429   MS_LOG(INFO) << "Start communicator with server.";
430   if (!communicator_with_server_->Start()) {
431     MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
432     return;
433   }
434   DistributedMetadataStore::GetInstance().Initialize(server_node_);
435   CollectiveOpsImpl::GetInstance().Initialize(server_node_);
436   DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
437   MS_LOG(INFO) << "This server rank is " << server_node_->rank_id();
438 
439   MS_LOG(INFO) << "Start communicator with worker.";
440   (void)std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
441                       [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
442                         MS_ERROR_IF_NULL_WO_RET_VAL(communicator);
443                         if (!communicator->Start()) {
444                           MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
445                         }
446                       });
447 }
448 
ProcessBeforeScalingOut()449 void Server::ProcessBeforeScalingOut() {
450   MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
451   iteration_->ScalingBarrier();
452   safemode_ = true;
453 }
454 
ProcessBeforeScalingIn()455 void Server::ProcessBeforeScalingIn() {
456   MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
457   iteration_->ScalingBarrier();
458   safemode_ = true;
459 }
460 
ProcessAfterScalingOut()461 void Server::ProcessAfterScalingOut() {
462   std::unique_lock<std::mutex> lock(scaling_mtx_);
463   MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
464   if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
465     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
466   }
467   if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
468     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
469   }
470   if (!DistributedCountService::GetInstance().ReInitForScaling()) {
471     MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
472   }
473   if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
474     MS_LOG(WARNING) << "Iteration reinitializing failed.";
475   }
476   if (!Executor::GetInstance().ReInitForScaling()) {
477     MS_LOG(WARNING) << "Executor reinitializing failed.";
478   }
479   std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
480   safemode_ = false;
481 }
482 
ProcessAfterScalingIn()483 void Server::ProcessAfterScalingIn() {
484   std::unique_lock<std::mutex> lock(scaling_mtx_);
485   MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
486   if (server_node_->rank_id() == UINT32_MAX) {
487     MS_LOG(WARNING) << "This server the one to be scaled in. Server need to wait SIGTERM to exit.";
488     return;
489   }
490 
491   // If the server is not the one to be scaled in, reintialize modules and recover service.
492   if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
493     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
494   }
495   if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
496     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
497   }
498   if (!DistributedCountService::GetInstance().ReInitForScaling()) {
499     MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
500   }
501   if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
502     MS_LOG(WARNING) << "Iteration reinitializing failed.";
503   }
504   if (!Executor::GetInstance().ReInitForScaling()) {
505     MS_LOG(WARNING) << "Executor reinitializing failed.";
506   }
507   std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
508   safemode_ = false;
509 }
510 
HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> & message)511 void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
512   MS_ERROR_IF_NULL_WO_RET_VAL(message);
513   MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
514   MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
515   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
516   MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
517 
518   std::string result_message = "";
519   bool result = iteration_->EnableServerInstance(&result_message);
520   nlohmann::json response;
521   response["result"] = result;
522   response["message"] = result_message;
523   if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
524     MS_LOG(ERROR) << "Sending response failed.";
525     return;
526   }
527 }
528 
HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> & message)529 void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
530   MS_ERROR_IF_NULL_WO_RET_VAL(message);
531   MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
532   MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
533   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
534   MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
535 
536   std::string result_message = "";
537   bool result = iteration_->DisableServerInstance(&result_message);
538   nlohmann::json response;
539   response["result"] = result;
540   response["message"] = result_message;
541   if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
542     MS_LOG(ERROR) << "Sending response failed.";
543     return;
544   }
545 }
546 
HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> & message)547 void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
548   MS_ERROR_IF_NULL_WO_RET_VAL(message);
549   MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
550   MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
551   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
552   MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
553 
554   MS_ERROR_IF_NULL_WO_RET_VAL(message->data());
555   std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
556   nlohmann::json new_instance_json;
557   nlohmann::json response;
558   try {
559     new_instance_json = nlohmann::json::parse(hyper_params_str);
560   } catch (const std::exception &e) {
561     response["result"] = false;
562     response["message"] = "The hyper-parameter data is not in json format.";
563     if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
564       MS_LOG(ERROR) << "Sending response failed.";
565       return;
566     }
567   }
568 
569   std::string result_message = "";
570   bool result = iteration_->NewInstance(new_instance_json, &result_message);
571   response["result"] = result;
572   response["message"] = result_message;
573   if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
574     MS_LOG(ERROR) << "Sending response failed.";
575     return;
576   }
577 }
578 
HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> & message)579 void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
580   MS_ERROR_IF_NULL_WO_RET_VAL(message);
581   nlohmann::basic_json<std::map, std::vector, std::string, bool, int64_t, uint64_t, float> response;
582   response["start_fl_job_threshold"] = ps::PSContext::instance()->start_fl_job_threshold();
583   response["start_fl_job_time_window"] = ps::PSContext::instance()->start_fl_job_time_window();
584   response["update_model_ratio"] = ps::PSContext::instance()->update_model_ratio();
585   response["update_model_time_window"] = ps::PSContext::instance()->update_model_time_window();
586   response["fl_iteration_num"] = ps::PSContext::instance()->fl_iteration_num();
587   response["client_epoch_num"] = ps::PSContext::instance()->client_epoch_num();
588   response["client_batch_size"] = ps::PSContext::instance()->client_batch_size();
589   response["client_learning_rate"] = ps::PSContext::instance()->client_learning_rate();
590   auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
591   MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
592   if (!tcp_comm->SendResponse(response.dump().c_str(), response.dump().size(), message)) {
593     MS_LOG(ERROR) << "Sending response failed.";
594     return;
595   }
596 }
597 }  // namespace server
598 }  // namespace fl
599 }  // namespace mindspore
600