• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/scheduler_node.h"
18 
19 namespace mindspore {
20 namespace ps {
21 namespace core {
~SchedulerNode()22 SchedulerNode::~SchedulerNode() {
23   MS_LOG(INFO) << "Stop scheduler node!";
24   if (!Stop()) {
25     MS_LOG(WARNING) << "Scheduler node stop failed.";
26   }
27 }
28 
Start(const uint32_t & timeout)29 bool SchedulerNode::Start(const uint32_t &timeout) {
30   MS_LOG(INFO) << "[Scheduler start]: 1. Begin to start scheduler node!";
31   if (PSContext::instance()->scheduler_manage_port() != 0) {
32     MS_LOG(WARNING) << "Start the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
33                     << ", the port:" << PSContext::instance()->scheduler_manage_port();
34     StartRestfulServer(kLocalIp, PSContext::instance()->scheduler_manage_port(), 1);
35   }
36   Initialize();
37   StartUpdateClusterStateTimer();
38   if (!WaitForStart(timeout)) {
39     MS_LOG(ERROR) << "Start Scheduler node timeout!";
40     return false;
41   }
42   node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
43   MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num()
44                << " workers and " << node_manager_.server_num() << " servers registered.";
45 
46   return true;
47 }
48 
ProcessHeartbeat(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)49 void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
50                                      const std::shared_ptr<TcpConnection> &conn,
51                                      const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
52   MS_EXCEPTION_IF_NULL(server);
53   MS_EXCEPTION_IF_NULL(conn);
54   MS_EXCEPTION_IF_NULL(meta);
55   MS_EXCEPTION_IF_NULL(data);
56   HeartbeatMessage heartbeat_message;
57   CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size)));
58 
59   node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
60 
61   HeartbeatRespMessage heartbeat_resp_message;
62 
63   MS_LOG(DEBUG) << "The cluster state:" << CommUtil::ClusterStateToString(node_manager_.GetClusterState());
64   heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());
65 
66   std::vector<ServersMeta> servers_meta_list = node_manager_.FetchAllNodesMeta();
67 
68   *heartbeat_resp_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
69 
70   heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
71 
72   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
73                            heartbeat_resp_message.ByteSizeLong())) {
74     MS_LOG(WARNING) << "Send heart beat failed.";
75   }
76 }
77 
Initialize()78 void SchedulerNode::Initialize() {
79   config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
80   MS_EXCEPTION_IF_NULL(config_);
81   if (!config_->Initialize()) {
82     MS_LOG(INFO) << "The config file is empty.";
83   }
84   InitCommandHandler();
85   CreateTcpServer();
86   is_already_stopped_ = false;
87   if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
88     node_info_.node_id_ = config_->Get(kNodeId, "");
89   } else {
90     node_info_.node_id_ = PSContext::instance()->node_id();
91   }
92 
93   if (node_info_.node_id_.empty()) {
94     node_info_.node_id_ = CommUtil::GenerateUUID();
95   }
96   node_info_.node_role_ = NodeRole::SCHEDULER;
97   leader_scaler_ = std::make_unique<LeaderScaler>(this);
98   MS_EXCEPTION_IF_NULL(leader_scaler_);
99   instance_manager_ = std::make_unique<InstanceManager>(this);
100   MS_LOG(INFO) << "[Scheduler start]: 2. The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
101                << ", the node id is:" << node_info_.node_id_ << " create a tcp server.";
102 }
103 
InitCommandHandler()104 void SchedulerNode::InitCommandHandler() {
105   handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
106   handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
107   handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
108   handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
109   handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
110   handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
111   handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
112 }
113 
CreateTcpServer()114 void SchedulerNode::CreateTcpServer() {
115   node_manager_.InitNode();
116 
117   std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
118   uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
119   server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port, config_.get());
120   MS_EXCEPTION_IF_NULL(server_);
121   server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
122                                   const Protos &, const void *data, size_t size) {
123     if (handlers_.count(meta->cmd()) == 0) {
124       MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
125     }
126     const auto &handler_ptr = handlers_[meta->cmd()];
127     (this->*handler_ptr)(server_, conn, meta, data, size);
128   });
129 
130   server_->Init();
131 
132   scheduler_thread_ = std::make_unique<std::thread>([this]() {
133     MS_LOG(INFO) << "The scheduler node start a tcp server!";
134     this->server_->Start();
135   });
136   MS_EXCEPTION_IF_NULL(scheduler_thread_);
137 }
138 
ProcessRegister(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)139 void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
140                                     const std::shared_ptr<TcpConnection> &conn,
141                                     const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
142   MS_EXCEPTION_IF_NULL(server);
143   MS_EXCEPTION_IF_NULL(conn);
144   MS_EXCEPTION_IF_NULL(meta);
145   MS_EXCEPTION_IF_NULL(data);
146   RegisterMessage register_message;
147   CHECK_RETURN_TYPE(register_message.ParseFromArray(data, SizeToInt(size)));
148 
149   const std::string &node_id = register_message.node_id();
150   node_manager_.UpdateHeartbeat(node_id);
151 
152   MS_LOG(INFO) << "The node id:" << node_id << " is registering to scheduler.";
153   client_mutex_.lock();
154   if (node_manager_.IsNodeRegistered(node_id)) {
155     MS_LOG(INFO) << "The node id is registered.";
156     if (connected_nodes_.count(node_id)) {
157       (void)connected_nodes_.erase(node_id);
158     }
159   }
160   client_mutex_.unlock();
161 
162   // assign worker node and server node rank id
163   uint32_t rank_id = node_manager_.NextRankId(register_message, meta);
164   if (rank_id == UINT32_MAX) {
165     MS_LOG(WARNING) << "The rank id is wrong!";
166   }
167 
168   RegisterRespMessage register_resp_message;
169   register_resp_message.set_node_id(node_id);
170 
171   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
172                            register_resp_message.ByteSizeLong())) {
173     MS_LOG(WARNING) << "Server response message failed.";
174   }
175 
176   if (node_manager_.IsAllNodesRegistered()) {
177     is_ready_ = true;
178     MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
179                  << " servers registered to scheduer, so the scheduler send meta data to worker/server.";
180     if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
181       auto nodes = node_manager_.nodes_info();
182       for (const auto &id : scale_in_node_ids_) {
183         MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
184         if (nodes.count(id)) {
185           auto scale_in_client = GetOrCreateClient(nodes[id]);
186           SendMetadata(scale_in_client, nodes[id].rank_id_);
187           node_manager_.UpdateHeartbeat(id);
188         }
189       }
190     }
191     node_manager_.UpdateNodesInfo();
192     auto node_infos = node_manager_.nodes_info();
193     for (const auto &kvs : node_infos) {
194       auto client = GetOrCreateClient(kvs.second);
195       MS_EXCEPTION_IF_NULL(client);
196       SendMetadata(client, kvs.second.rank_id_);
197       node_manager_.UpdateHeartbeat(kvs.first);
198     }
199     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
200     wait_start_cond_.notify_all();
201   }
202 }
203 
ProcessFinish(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)204 void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
205                                   const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
206   MS_EXCEPTION_IF_NULL(server);
207   MS_EXCEPTION_IF_NULL(conn);
208   MS_EXCEPTION_IF_NULL(meta);
209   MS_EXCEPTION_IF_NULL(data);
210   auto finish_message = std::make_unique<std::string>(reinterpret_cast<const char *>(data), size);
211   MS_EXCEPTION_IF_NULL(finish_message);
212   std::string node_id = *finish_message;
213   MS_LOG(INFO) << "Process finish message from node id:" << node_id;
214   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
215     MS_LOG(WARNING) << "Server response message failed.";
216   }
217 
218   auto iter = std::find_if(scale_in_node_ids_.begin(), scale_in_node_ids_.end(), [node_id](auto item) {
219     if (node_id == item) {
220       MS_LOG(INFO) << "The finish node is a scale in node.";
221       return true;
222     }
223     return false;
224   });
225   if (iter != scale_in_node_ids_.end()) {
226     return;
227   }
228 
229   node_manager_.AddFinishNode(node_id);
230   if (node_manager_.IsAllNodesFinished()) {
231     auto node_infos = node_manager_.nodes_info();
232     for (const auto &kvs : node_infos) {
233       auto client = GetOrCreateClient(kvs.second);
234       SendFinish(client);
235     }
236     is_finish_ = true;
237     node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
238     wait_finish_cond_.notify_all();
239   }
240 }
241 
ProcessFetchMetadata(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t)242 void SchedulerNode::ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server,
243                                          const std::shared_ptr<TcpConnection> &conn,
244                                          const std::shared_ptr<MessageMeta> &meta, const void *data, size_t) {
245   MS_EXCEPTION_IF_NULL(server);
246   MS_EXCEPTION_IF_NULL(conn);
247   MS_EXCEPTION_IF_NULL(meta);
248   MS_EXCEPTION_IF_NULL(data);
249   FetchServersRespMessage fetch_servers_message;
250   std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
251 
252   *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
253 
254   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
255                            fetch_servers_message.ByteSizeLong())) {
256     MS_LOG(WARNING) << "Server response message failed.";
257   }
258 }
259 
ProcessScaleOutDone(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)260 void SchedulerNode::ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server,
261                                         const std::shared_ptr<TcpConnection> &conn,
262                                         const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
263   MS_EXCEPTION_IF_NULL(server);
264   MS_EXCEPTION_IF_NULL(conn);
265   MS_EXCEPTION_IF_NULL(meta);
266   MS_EXCEPTION_IF_NULL(data);
267   ScaleOutDoneMessage scale_out_done_message;
268   scale_out_done_message.ParseFromArray(data, SizeToInt(size));
269   std::string node_id = scale_out_done_message.node_id();
270   MS_LOG(INFO) << "The scheduler process a scale_out_done message from node id:" << node_id;
271   node_manager_.AddScaleOutDoneNode(node_id);
272 
273   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
274     MS_LOG(WARNING) << "Server response message failed.";
275   }
276 
277   if (node_manager_.IsAllNodesScaleOutDone()) {
278     auto node_infos = node_manager_.nodes_info();
279     for (const auto &kvs : node_infos) {
280       auto client = GetOrCreateClient(kvs.second);
281       SendScaleOutDone(client);
282     }
283     is_ready_ = true;
284     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
285   }
286 }
287 
ProcessScaleInDone(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)288 void SchedulerNode::ProcessScaleInDone(const std::shared_ptr<TcpServer> &server,
289                                        const std::shared_ptr<TcpConnection> &conn,
290                                        const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
291   MS_EXCEPTION_IF_NULL(server);
292   MS_EXCEPTION_IF_NULL(conn);
293   MS_EXCEPTION_IF_NULL(meta);
294   MS_EXCEPTION_IF_NULL(data);
295   ScaleInDoneMessage scale_in_done_message;
296   scale_in_done_message.ParseFromArray(data, SizeToInt(size));
297   std::string node_id = scale_in_done_message.node_id();
298   MS_LOG(INFO) << "The scheduler process a scale_in_done message from node id:" << node_id;
299   node_manager_.AddScaleInDoneNode(node_id);
300 
301   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
302     MS_LOG(WARNING) << "Server response message failed.";
303   }
304 
305   if (node_manager_.IsAllNodesScaleInDone()) {
306     auto node_infos = node_manager_.nodes_info();
307     for (const auto &kvs : node_infos) {
308       auto client = GetOrCreateClient(kvs.second);
309       SendScaleInDone(client);
310     }
311     is_ready_ = true;
312     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
313   }
314 }
315 
ProcessSendEvent(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)316 void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
317                                      const std::shared_ptr<TcpConnection> &conn,
318                                      const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
319   MS_EXCEPTION_IF_NULL(server);
320   MS_EXCEPTION_IF_NULL(conn);
321   MS_EXCEPTION_IF_NULL(meta);
322   MS_EXCEPTION_IF_NULL(data);
323   EventMessage event_message;
324   event_message.ParseFromArray(data, SizeToInt(size));
325   std::string node_id = event_message.node_id();
326   uint32_t event = event_message.event();
327   MS_LOG(DEBUG) << "The scheduler process a event message from node id:" << node_id;
328 
329   if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
330     MS_LOG(WARNING) << "Server response message failed.";
331   }
332 
333   auto node_infos = node_manager_.nodes_info();
334   for (const auto &kvs : node_infos) {
335     auto client = GetOrCreateClient(kvs.second);
336     SendEvent(client, event);
337   }
338 }
339 
SendMetadata(const std::shared_ptr<TcpClient> & client,uint32_t rank_id)340 void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id) {
341   MS_EXCEPTION_IF_NULL(client);
342   auto message_meta = std::make_shared<MessageMeta>();
343   MS_EXCEPTION_IF_NULL(message_meta);
344   message_meta->set_cmd(NodeCommand::SEND_METADATA);
345 
346   SendMetadataMessage send_metadata_message;
347   std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
348   send_metadata_message.set_worker_num(node_manager_.worker_num());
349   send_metadata_message.set_server_num(node_manager_.server_num());
350   send_metadata_message.set_cluster_state(node_manager_.GetClusterState());
351   send_metadata_message.set_rank_id(rank_id);
352 
353   *send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
354 
355   if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
356                         send_metadata_message.ByteSizeLong())) {
357     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
358                       << " the node id:" << node_info_.node_id_ << " send metadata timeout!";
359   }
360 
361   MS_LOG(DEBUG) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
362                 << " the node id:" << node_info_.node_id_ << "is sending metadata to workers and servers!";
363 }
364 
SendFinish(const std::shared_ptr<TcpClient> & client)365 void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
366   MS_EXCEPTION_IF_NULL(client);
367   auto message_meta = std::make_shared<MessageMeta>();
368   MS_EXCEPTION_IF_NULL(message_meta);
369   message_meta->set_cmd(NodeCommand::FINISH);
370 
371   // The scheduler does not need to bring any data when sending the finish command
372   std::string resp_data;
373 
374   if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
375     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
376                       << " the node id:" << node_info_.node_id_ << " send finish timeout!";
377   }
378 
379   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
380                << " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
381 }
382 
SendScaleOutDone(const std::shared_ptr<TcpClient> & client)383 void SchedulerNode::SendScaleOutDone(const std::shared_ptr<TcpClient> &client) {
384   MS_EXCEPTION_IF_NULL(client);
385   auto message_meta = std::make_shared<MessageMeta>();
386   MS_EXCEPTION_IF_NULL(message_meta);
387   message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
388 
389   // The scheduler does not need to bring any data when sending the scale_out_done command
390   std::string resp_data;
391 
392   if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
393     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
394                       << " the node id:" << node_info_.node_id_ << " send scale_out_done timeout!";
395   }
396 
397   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
398                << " the node id:" << node_info_.node_id_ << "is sending scale_out_done to workers and servers!";
399 }
400 
SendScaleInDone(const std::shared_ptr<TcpClient> & client)401 void SchedulerNode::SendScaleInDone(const std::shared_ptr<TcpClient> &client) {
402   MS_EXCEPTION_IF_NULL(client);
403   auto message_meta = std::make_shared<MessageMeta>();
404   MS_EXCEPTION_IF_NULL(message_meta);
405   message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
406 
407   // The scheduler does not need to bring any data when sending the scale_in_done command
408   std::string resp_data;
409 
410   if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
411     MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
412                       << " the node id:" << node_info_.node_id_ << " send scale_in_done timeout!";
413   }
414 
415   MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
416                << " the node id:" << node_info_.node_id_ << "is sending scale_in_done to workers and servers!";
417 }
418 
SendEvent(const std::shared_ptr<TcpClient> & client,const uint32_t & event)419 void SchedulerNode::SendEvent(const std::shared_ptr<TcpClient> &client, const uint32_t &event) {
420   MS_EXCEPTION_IF_NULL(client);
421   auto message_meta = std::make_shared<MessageMeta>();
422   MS_EXCEPTION_IF_NULL(message_meta);
423   message_meta->set_cmd(NodeCommand::SEND_EVENT);
424 
425   EventRespMessage event_resp_message;
426   event_resp_message.set_event(event);
427 
428   if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
429                        event_resp_message.ByteSizeLong())) {
430     MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
431                   << " the node id:" << node_info_.node_id_ << " send event resp timeout!";
432     return;
433   }
434 
435   MS_LOG(DEBUG) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
436                 << " the node id:" << node_info_.node_id_ << "is sending event resp to workers and servers!";
437 }
438 
StartUpdateClusterStateTimer()439 void SchedulerNode::StartUpdateClusterStateTimer() {
440   MS_LOG(INFO) << "[Scheduler start]: 3. The scheduler start a heartbeat timer!";
441   update_state_thread_ = std::make_unique<std::thread>([&]() {
442     auto start_time = std::chrono::steady_clock::now();
443     while (!is_finish_.load()) {
444       // 1. update cluster timeout
445       if (!is_ready_ && (std::chrono::steady_clock::now() - start_time >
446                          std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
447         node_manager_.CheckClusterTimeout();
448       }
449       std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
450       node_manager_.UpdateCluster();
451 
452       if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
453         std::this_thread::sleep_for(
454           std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * kHeartbeatTimes));
455         is_finish_ = true;
456         wait_finish_cond_.notify_all();
457       }
458     }
459   });
460   MS_EXCEPTION_IF_NULL(update_state_thread_);
461 }
462 
GetOrCreateClient(const NodeInfo & node_info)463 const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInfo &node_info) {
464   std::lock_guard<std::mutex> lock(client_mutex_);
465   if (connected_nodes_.count(node_info.node_id_)) {
466     return connected_nodes_[node_info.node_id_];
467   } else {
468     if (config_ == nullptr) {
469       MS_LOG(EXCEPTION) << "The config is empty.";
470     }
471     std::string ip = node_info.ip_;
472     uint16_t port = node_info.port_;
473     auto client = std::make_shared<TcpClient>(ip, port, config_.get());
474     MS_EXCEPTION_IF_NULL(client);
475     client->SetMessageCallback(
476       [&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
477         switch (meta->cmd()) {
478           case NodeCommand::SEND_DATA:
479             ProcessSendDataResp(meta, protos, data, size);
480             RunMessageCallback(meta->request_id());
481             break;
482           default:
483             MS_LOG(DEBUG) << "The cmd:" << meta->cmd();
484         }
485         NotifyMessageArrival(meta);
486       });
487     client->Init();
488     if (is_client_started_ == false) {
489       is_client_started_ = true;
490       client_thread_ = std::make_unique<std::thread>([&]() {
491         MS_LOG(INFO) << "The node start a tcp client!";
492         client->Start();
493       });
494       MS_EXCEPTION_IF_NULL(client_thread_);
495     }
496 
497     connected_nodes_[node_info.node_id_] = client;
498     return connected_nodes_[node_info.node_id_];
499   }
500 }
501 
Stop()502 bool SchedulerNode::Stop() {
503   MS_LOG(INFO) << "Stop scheduler node!";
504   if (!is_already_stopped_) {
505     MS_ERROR_IF_NULL_W_RET_VAL(update_state_thread_, false);
506     MS_ERROR_IF_NULL_W_RET_VAL(server_, false);
507     MS_ERROR_IF_NULL_W_RET_VAL(scheduler_thread_, false);
508     is_already_stopped_ = true;
509     update_state_thread_->join();
510     server_->Stop();
511     scheduler_thread_->join();
512     if (!connected_nodes_.empty()) {
513       for (auto &connected_node : connected_nodes_) {
514         auto client = connected_node.second;
515         MS_ERROR_IF_NULL_W_RET_VAL(client, false);
516         client->Stop();
517       }
518     }
519     if (client_thread_ != nullptr && client_thread_->joinable()) {
520       client_thread_->join();
521     }
522     is_ready_ = true;
523   }
524   if (PSContext::instance()->scheduler_manage_port() != 0) {
525     MS_LOG(WARNING) << "Stop the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
526                     << ", the port:" << PSContext::instance()->scheduler_manage_port();
527     StopRestfulServer();
528   }
529   return true;
530 }
531 
Finish(const uint32_t &)532 bool SchedulerNode::Finish(const uint32_t &) {
533   MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
534   std::unique_lock<std::mutex> lock(wait_finish_mutex_);
535   wait_finish_cond_.wait(lock, [this] {
536     if (this->is_finish_.load()) {
537       MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
538     }
539     return this->is_finish_.load();
540   });
541   return true;
542 }
543 
ProcessScaleOut(const std::shared_ptr<HttpMessageHandler> & resp)544 void SchedulerNode::ProcessScaleOut(const std::shared_ptr<HttpMessageHandler> &resp) {
545   MS_EXCEPTION_IF_NULL(resp);
546   RequestProcessResult status(RequestProcessResultCode::kSuccess);
547   status = resp->ParsePostMessageToJson();
548   if (status != RequestProcessResultCode::kSuccess) {
549     resp->ErrorResponse(HTTP_BADREQUEST, status);
550     return;
551   }
552 
553   int32_t scale_worker_num = 0;
554   status = resp->ParseValueFromKey(kWorkerNum, &scale_worker_num);
555   if (status != RequestProcessResultCode::kSuccess) {
556     resp->ErrorResponse(HTTP_BADREQUEST, status);
557     return;
558   }
559 
560   int32_t scale_server_num = 0;
561   status = resp->ParseValueFromKey(kServerNum, &scale_server_num);
562   if (status != RequestProcessResultCode::kSuccess) {
563     resp->ErrorResponse(HTTP_BADREQUEST, status);
564     return;
565   }
566 
567   status = CheckIfClusterReady();
568   if (status != RequestProcessResultCode::kSuccess) {
569     resp->ErrorResponse(HTTP_BADREQUEST, status);
570     return;
571   }
572 
573   int32_t total_worker_num = scale_worker_num + node_manager_.worker_num();
574   int32_t total_server_num = scale_server_num + node_manager_.server_num();
575 
576   MS_LOG(INFO) << "After scale out, the total worker num:" << total_worker_num
577                << ", the total server num:" << total_server_num;
578 
579   node_manager_.set_worker_num(total_worker_num);
580   node_manager_.set_server_num(total_server_num);
581   node_manager_.set_total_node_num(total_worker_num + total_server_num);
582 
583   node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
584   auto node_infos = node_manager_.nodes_info();
585   node_manager_.ResetMetadata();
586   for (const auto &kvs : node_infos) {
587     auto client = GetOrCreateClient(kvs.second);
588     MS_EXCEPTION_IF_NULL(client);
589     MS_EXCEPTION_IF_NULL(leader_scaler_);
590     leader_scaler_->ScaleOutAsync(client, node_manager_);
591   }
592   MS_LOG(INFO) << "Scheduler send scale out successful.";
593 
594   nlohmann::json js;
595   js["message"] = "Cluster begin to scale out.";
596   resp->AddRespString(js.dump());
597   resp->AddRespHeadParam("Content-Type", "application/json");
598 
599   resp->SetRespCode(HTTP_OK);
600   resp->SendResponse();
601 }
602 
603 /*
604  * The response body format.
605  * {
606  *    "node_ids": ["node_id1", "node_id2"]
607  * }
608  */
ProcessScaleIn(const std::shared_ptr<HttpMessageHandler> & resp)609 void SchedulerNode::ProcessScaleIn(const std::shared_ptr<HttpMessageHandler> &resp) {
610   MS_EXCEPTION_IF_NULL(resp);
611   RequestProcessResult status(RequestProcessResultCode::kSuccess);
612   status = resp->ParsePostMessageToJson();
613   if (status != RequestProcessResultCode::kSuccess) {
614     resp->ErrorResponse(HTTP_BADREQUEST, status);
615   }
616 
617   status = CheckIfClusterReady();
618   if (status != RequestProcessResultCode::kSuccess) {
619     resp->ErrorResponse(HTTP_BADREQUEST, status);
620     return;
621   }
622 
623   scale_in_node_ids_.clear();
624   status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids_);
625   if (status != RequestProcessResultCode::kSuccess) {
626     resp->ErrorResponse(HTTP_BADREQUEST, status);
627     return;
628   }
629 
630   status = CheckIfNodeIdLegal(scale_in_node_ids_);
631   if (status != RequestProcessResultCode::kSuccess) {
632     resp->ErrorResponse(HTTP_BADREQUEST, status);
633     return;
634   }
635 
636   MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids_;
637 
638   std::unordered_map<std::string, bool> scale_in_nodes;
639 
640   int32_t scale_worker_num = 0;
641   int32_t scale_server_num = 0;
642   auto node_infos = node_manager_.nodes_info();
643   node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
644   node_manager_.ResetMetadata(scale_in_node_ids_);
645   for (auto const &val : scale_in_node_ids_) {
646     if (node_infos.count(val)) {
647       scale_in_nodes[val] = true;
648       NodeInfo info = node_infos[val];
649       if (info.node_role_ == NodeRole::WORKER) {
650         scale_worker_num++;
651       } else if (info.node_role_ == NodeRole::SERVER) {
652         scale_server_num++;
653       }
654     }
655   }
656 
657   MS_LOG(INFO) << "The scale worker num:" << scale_worker_num << ", the scale server num:" << scale_server_num;
658 
659   int32_t total_worker_num = node_manager_.worker_num() - scale_worker_num;
660   int32_t total_server_num = node_manager_.server_num() - scale_server_num;
661 
662   node_manager_.set_worker_num(total_worker_num);
663   node_manager_.set_server_num(total_server_num);
664   node_manager_.set_total_node_num(total_worker_num + total_server_num);
665   for (const auto &kvs : node_infos) {
666     auto client = GetOrCreateClient(kvs.second);
667     bool is_node_scale_in = false;
668     if (scale_in_nodes.count(kvs.first)) {
669       is_node_scale_in = true;
670     }
671     MS_EXCEPTION_IF_NULL(leader_scaler_);
672     leader_scaler_->ScaleInAsync(client, node_manager_, is_node_scale_in);
673   }
674 
675   nlohmann::json js;
676   js["message"] = "Cluster begin to scale in.";
677   resp->AddRespString(js.dump());
678   resp->AddRespHeadParam("Content-Type", "application/json");
679 
680   resp->SetRespCode(HTTP_OK);
681   resp->SendResponse();
682 }
683 
684 /*
685  * The response body format.
686  * {
687  *    "message": "Get nodes info successful.",
688  *    "node_ids": [
689  *        {
690  *            "node_id": "node_id1",
691  *            "rank_id": "0",
692  *            "role": "SERVER"
693  *        },
694  *        {
695  *            "node_id": "node_id2",
696  *            "rank_id": "1",
697  *            "role": "WORKER"
698  *        }
699  *    ]
700  * }
701  */
ProcessGetNodesInfo(const std::shared_ptr<HttpMessageHandler> & resp)702 void SchedulerNode::ProcessGetNodesInfo(const std::shared_ptr<HttpMessageHandler> &resp) {
703   MS_EXCEPTION_IF_NULL(resp);
704   nlohmann::json js;
705   js["message"] = "Get nodes info successful.";
706   auto node_infos = node_manager_.nodes_info();
707   for (const auto &kvs : node_infos) {
708     std::unordered_map<std::string, std::string> res;
709     res["node_id"] = kvs.second.node_id_;
710     res["rank_id"] = std::to_string(kvs.second.rank_id_);
711     res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
712     js["node_ids"].push_back(res);
713   }
714 
715   resp->AddRespString(js.dump());
716   resp->AddRespHeadParam("Content-Type", "application/json");
717 
718   resp->SetRespCode(HTTP_OK);
719   resp->SendResponse();
720 }
721 
722 /*
723  * The response body format.
724  * {
725  *    "message": "Get cluster state successful.",
726  *    "cluster_state": "CLUSTER_READY"
727  * }
728  */
ProcessGetClusterState(const std::shared_ptr<HttpMessageHandler> & resp)729 void SchedulerNode::ProcessGetClusterState(const std::shared_ptr<HttpMessageHandler> &resp) {
730   MS_EXCEPTION_IF_NULL(resp);
731   nlohmann::json js;
732   js["message"] = "Get cluster state successful.";
733   auto cluster_state = node_manager_.GetClusterState();
734   js["cluster_state"] = CommUtil::ClusterStateToString(cluster_state);
735 
736   resp->AddRespString(js.dump());
737   resp->AddRespHeadParam("Content-Type", "application/json");
738 
739   resp->SetRespCode(HTTP_OK);
740   resp->SendResponse();
741 }
742 
ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> & resp)743 void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> &resp) {
744   MS_EXCEPTION_IF_NULL(resp);
745 
746   RequestProcessResult status(RequestProcessResultCode::kSuccess);
747 
748   status = CheckIfClusterReady();
749   if (status != RequestProcessResultCode::kSuccess) {
750     resp->ErrorResponse(HTTP_BADREQUEST, status);
751     return;
752   }
753 
754   status = resp->ParsePostMessageToJson();
755   if (status != RequestProcessResultCode::kSuccess) {
756     resp->ErrorResponse(HTTP_BADREQUEST, status);
757     return;
758   }
759 
760   node_manager_.UpdateClusterState(ClusterState::CLUSTER_NEW_INSTANCE);
761 
762   std::string body = resp->request_message().dump();
763 
764   uint64_t request_id = AddMessageTrack(node_manager_.server_num());
765 
766   std::unordered_map<uint32_t, VectorPtr> outputs;
767 
768   set_message_callback(request_id, [&]() {
769     receive_messages_mutex_.lock();
770     outputs = receive_messages_[request_id];
771     receive_messages_.erase(request_id);
772     receive_messages_mutex_.unlock();
773   });
774 
775   auto node_infos = node_manager_.nodes_info();
776   for (const auto &kvs : node_infos) {
777     if (kvs.second.node_role_ == NodeRole::SERVER) {
778       auto client = GetOrCreateClient(kvs.second);
779       MS_EXCEPTION_IF_NULL(client);
780       MS_EXCEPTION_IF_NULL(instance_manager_);
781       instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_);
782     }
783   }
784   bool res = Wait(request_id);
785   if (!res) {
786     ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The new instance is timeout.");
787     resp->ErrorResponse(HTTP_BADREQUEST, status);
788     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
789     return;
790   }
791 
792   node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
793   nlohmann::json js;
794   js["message"] = "Start update flPlan successful.";
795   for (const auto &output : outputs) {
796     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
797     js["result"][output.first] = data;
798   }
799 
800   resp->AddRespString(js.dump());
801   resp->AddRespHeadParam("Content-Type", "application/json");
802 
803   resp->SetRespCode(HTTP_OK);
804   resp->SendResponse();
805 }
806 
ProcessQueryInstance(const std::shared_ptr<HttpMessageHandler> & resp)807 void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandler> &resp) {
808   MS_EXCEPTION_IF_NULL(resp);
809 
810   RequestProcessResult status(RequestProcessResultCode::kSuccess);
811 
812   status = CheckIfClusterReady();
813   if (status != RequestProcessResultCode::kSuccess) {
814     resp->ErrorResponse(HTTP_BADREQUEST, status);
815     return;
816   }
817 
818   uint64_t request_id = AddMessageTrack(node_manager_.server_num());
819 
820   std::unordered_map<uint32_t, VectorPtr> outputs;
821 
822   set_message_callback(request_id, [&]() {
823     receive_messages_mutex_.lock();
824     outputs = receive_messages_[request_id];
825     receive_messages_.erase(request_id);
826     receive_messages_mutex_.unlock();
827   });
828 
829   auto node_infos = node_manager_.nodes_info();
830   for (const auto &kvs : node_infos) {
831     if (kvs.second.node_role_ == NodeRole::SERVER) {
832       auto client = GetOrCreateClient(kvs.second);
833       MS_EXCEPTION_IF_NULL(client);
834       MS_EXCEPTION_IF_NULL(instance_manager_);
835       instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, node_info_);
836     }
837   }
838   bool res = Wait(request_id);
839   if (!res) {
840     ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The query instance is timeout.");
841     resp->ErrorResponse(HTTP_BADREQUEST, status);
842     return;
843   }
844 
845   nlohmann::json js;
846   js["message"] = "Start update flPlan successful.";
847   for (const auto &output : outputs) {
848     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
849     js["result"][output.first] = data;
850   }
851 
852   resp->AddRespString(js.dump());
853   resp->AddRespHeadParam("Content-Type", "application/json");
854 
855   resp->SetRespCode(HTTP_OK);
856   resp->SendResponse();
857 }
858 
ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & resp)859 void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &resp) {
860   MS_EXCEPTION_IF_NULL(resp);
861 
862   RequestProcessResult status(RequestProcessResultCode::kSuccess);
863 
864   status = CheckIfClusterReady();
865   if (status != RequestProcessResultCode::kSuccess) {
866     resp->ErrorResponse(HTTP_BADREQUEST, status);
867     return;
868   }
869 
870   node_manager_.UpdateClusterState(ClusterState::CLUSTER_ENABLE_FLS);
871 
872   uint64_t request_id = AddMessageTrack(node_manager_.server_num());
873 
874   std::unordered_map<uint32_t, VectorPtr> outputs;
875 
876   set_message_callback(request_id, [&]() {
877     receive_messages_mutex_.lock();
878     outputs = receive_messages_[request_id];
879     receive_messages_.erase(request_id);
880     receive_messages_mutex_.unlock();
881   });
882 
883   auto node_infos = node_manager_.nodes_info();
884   for (const auto &kvs : node_infos) {
885     if (kvs.second.node_role_ == NodeRole::SERVER) {
886       auto client = GetOrCreateClient(kvs.second);
887       MS_EXCEPTION_IF_NULL(client);
888       MS_EXCEPTION_IF_NULL(instance_manager_);
889       instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_);
890     }
891   }
892   bool res = Wait(request_id);
893   if (!res) {
894     ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The enable FLS is timeout.");
895     resp->ErrorResponse(HTTP_BADREQUEST, status);
896     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
897     return;
898   }
899 
900   node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
901   nlohmann::json js;
902   js["message"] = "start enabling FL-Server successful.";
903   for (const auto &output : outputs) {
904     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
905     js["result"][output.first] = data;
906   }
907 
908   resp->AddRespString(js.dump());
909   resp->AddRespHeadParam("Content-Type", "application/json");
910 
911   resp->SetRespCode(HTTP_OK);
912   resp->SendResponse();
913 }
914 
ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> & resp)915 void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> &resp) {
916   MS_EXCEPTION_IF_NULL(resp);
917 
918   RequestProcessResult status(RequestProcessResultCode::kSuccess);
919 
920   status = CheckIfClusterReady();
921   if (status != RequestProcessResultCode::kSuccess) {
922     resp->ErrorResponse(HTTP_BADREQUEST, status);
923     return;
924   }
925 
926   node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
927 
928   uint64_t request_id = AddMessageTrack(node_manager_.server_num());
929 
930   std::unordered_map<uint32_t, VectorPtr> outputs;
931 
932   set_message_callback(request_id, [&]() {
933     receive_messages_mutex_.lock();
934     outputs = receive_messages_[request_id];
935     receive_messages_.erase(request_id);
936     receive_messages_mutex_.unlock();
937   });
938 
939   auto node_infos = node_manager_.nodes_info();
940   for (const auto &kvs : node_infos) {
941     if (kvs.second.node_role_ == NodeRole::SERVER) {
942       auto client = GetOrCreateClient(kvs.second);
943       MS_EXCEPTION_IF_NULL(client);
944       MS_EXCEPTION_IF_NULL(instance_manager_);
945       instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_);
946     }
947   }
948   bool res = Wait(request_id);
949   if (!res) {
950     ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout.");
951     resp->ErrorResponse(HTTP_BADREQUEST, status);
952     node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
953     return;
954   }
955 
956   node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
957   nlohmann::json js;
958   js["message"] = "start disabling FL-Server successful.";
959   for (const auto &output : outputs) {
960     std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
961     js["result"][output.first] = data;
962   }
963 
964   resp->AddRespString(js.dump());
965   resp->AddRespHeadParam("Content-Type", "application/json");
966 
967   resp->SetRespCode(HTTP_OK);
968   resp->SendResponse();
969 }
970 
CheckIfClusterReady()971 RequestProcessResult SchedulerNode::CheckIfClusterReady() {
972   RequestProcessResult result(RequestProcessResultCode::kSuccess);
973   if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
974     std::string message = "The cluster is not ready.";
975     ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
976     return result;
977   }
978   return result;
979 }
980 
CheckIfNodeIdLegal(const std::vector<std::string> & node_ids)981 RequestProcessResult SchedulerNode::CheckIfNodeIdLegal(const std::vector<std::string> &node_ids) {
982   RequestProcessResult result(RequestProcessResultCode::kSuccess);
983   if (node_ids.size() == 0) {
984     std::string message = "The node ids should not be empty.";
985     ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
986     return result;
987   }
988 
989   auto node_infos = node_manager_.nodes_info();
990 
991   for (auto val : node_ids) {
992     if (!node_infos.count(val)) {
993       std::string message = "The node id:" + val + " is illegal.";
994       MS_LOG(ERROR) << message;
995       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
996       return result;
997     }
998 
999     if (node_infos[val].node_role_ == NodeRole::SERVER && node_infos[val].rank_id_ == 0) {
1000       std::string error_message = "The node id:" + val + " is rank 0 of server, should not be scale in.";
1001       MS_LOG(ERROR) << error_message;
1002       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
1003       return result;
1004     }
1005 
1006     if (node_infos[val].node_role_ == NodeRole::WORKER) {
1007       std::string error_message = "The node id:" + val + " is the role of worker, should not be scale in.";
1008       MS_LOG(ERROR) << error_message;
1009       ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
1010       return result;
1011     }
1012   }
1013 
1014   return result;
1015 }
1016 
StartRestfulServer(const std::string & address,std::uint16_t port,size_t thread_num)1017 void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
1018   MS_LOG(INFO) << "Scheduler start https server.";
1019   http_server_ = std::make_shared<HttpServer>(address, port, thread_num);
1020   MS_EXCEPTION_IF_NULL(http_server_);
1021 
1022   OnRequestReceive scale_out = std::bind(&SchedulerNode::ProcessScaleOut, this, std::placeholders::_1);
1023   callbacks_["/scaleout"] = scale_out;
1024   http_server_->RegisterRoute("/scaleout", &callbacks_["/scaleout"]);
1025 
1026   OnRequestReceive scale_in = std::bind(&SchedulerNode::ProcessScaleIn, this, std::placeholders::_1);
1027   callbacks_["/scalein"] = scale_in;
1028   http_server_->RegisterRoute("/scalein", &callbacks_["/scalein"]);
1029 
1030   OnRequestReceive nodes = std::bind(&SchedulerNode::ProcessGetNodesInfo, this, std::placeholders::_1);
1031   callbacks_["/nodes"] = nodes;
1032   http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);
1033 
1034   OnRequestReceive cluster_state = std::bind(&SchedulerNode::ProcessGetClusterState, this, std::placeholders::_1);
1035   callbacks_["/state"] = cluster_state;
1036   http_server_->RegisterRoute("/state", &callbacks_["/state"]);
1037 
1038   OnRequestReceive new_instance = std::bind(&SchedulerNode::ProcessNewInstance, this, std::placeholders::_1);
1039   callbacks_["/newInstance"] = new_instance;
1040   http_server_->RegisterRoute("/newInstance", &callbacks_["/newInstance"]);
1041 
1042   OnRequestReceive query_instance = std::bind(&SchedulerNode::ProcessQueryInstance, this, std::placeholders::_1);
1043   callbacks_["/queryInstance"] = query_instance;
1044   http_server_->RegisterRoute("/queryInstance", &callbacks_["/queryInstance"]);
1045 
1046   OnRequestReceive enable_fls = std::bind(&SchedulerNode::ProcessEnableFLS, this, std::placeholders::_1);
1047   callbacks_["/enableFLS"] = enable_fls;
1048   http_server_->RegisterRoute("/enableFLS", &callbacks_["/enableFLS"]);
1049 
1050   OnRequestReceive disable_fls = std::bind(&SchedulerNode::ProcessDisableFLS, this, std::placeholders::_1);
1051   callbacks_["/disableFLS"] = disable_fls;
1052   http_server_->RegisterRoute("/disableFLS", &callbacks_["/disableFLS"]);
1053 
1054   if (!http_server_->InitServer()) {
1055     MS_LOG(EXCEPTION) << "The scheduler init http server failed.";
1056   }
1057 
1058   if (!http_server_->Start(false)) {
1059     MS_LOG(EXCEPTION) << "The scheduler start http server failed.";
1060   }
1061   restful_thread_ = std::make_unique<std::thread>([&]() { http_server_->Wait(); });
1062   MS_EXCEPTION_IF_NULL(restful_thread_);
1063 }
1064 
StopRestfulServer()1065 void SchedulerNode::StopRestfulServer() {
1066   MS_LOG(INFO) << "Scheduler stop https server.";
1067   MS_ERROR_IF_NULL_WO_RET_VAL(http_server_);
1068   MS_ERROR_IF_NULL_WO_RET_VAL(restful_thread_);
1069   if (!http_server_->Stop()) {
1070     MS_LOG(WARNING) << "Scheduler stop https server failed.";
1071   }
1072   if (restful_thread_ != nullptr && restful_thread_->joinable()) {
1073     restful_thread_->join();
1074   }
1075 }
1076 }  // namespace core
1077 }  // namespace ps
1078 }  // namespace mindspore
1079