1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/scheduler_node.h"
18
19 namespace mindspore {
20 namespace ps {
21 namespace core {
~SchedulerNode()22 SchedulerNode::~SchedulerNode() {
23 MS_LOG(INFO) << "Stop scheduler node!";
24 if (!Stop()) {
25 MS_LOG(WARNING) << "Scheduler node stop failed.";
26 }
27 }
28
Start(const uint32_t & timeout)29 bool SchedulerNode::Start(const uint32_t &timeout) {
30 MS_LOG(INFO) << "[Scheduler start]: 1. Begin to start scheduler node!";
31 if (PSContext::instance()->scheduler_manage_port() != 0) {
32 MS_LOG(WARNING) << "Start the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
33 << ", the port:" << PSContext::instance()->scheduler_manage_port();
34 StartRestfulServer(kLocalIp, PSContext::instance()->scheduler_manage_port(), 1);
35 }
36 Initialize();
37 StartUpdateClusterStateTimer();
38 if (!WaitForStart(timeout)) {
39 MS_LOG(ERROR) << "Start Scheduler node timeout!";
40 return false;
41 }
42 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
43 MS_LOG(INFO) << "[Scheduler start]: 4. Successfully start scheduler, there are " << node_manager_.worker_num()
44 << " workers and " << node_manager_.server_num() << " servers registered.";
45
46 return true;
47 }
48
ProcessHeartbeat(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)49 void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
50 const std::shared_ptr<TcpConnection> &conn,
51 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
52 MS_EXCEPTION_IF_NULL(server);
53 MS_EXCEPTION_IF_NULL(conn);
54 MS_EXCEPTION_IF_NULL(meta);
55 MS_EXCEPTION_IF_NULL(data);
56 HeartbeatMessage heartbeat_message;
57 CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size)));
58
59 node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
60
61 HeartbeatRespMessage heartbeat_resp_message;
62
63 MS_LOG(DEBUG) << "The cluster state:" << CommUtil::ClusterStateToString(node_manager_.GetClusterState());
64 heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());
65
66 std::vector<ServersMeta> servers_meta_list = node_manager_.FetchAllNodesMeta();
67
68 *heartbeat_resp_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
69
70 heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
71
72 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
73 heartbeat_resp_message.ByteSizeLong())) {
74 MS_LOG(WARNING) << "Send heart beat failed.";
75 }
76 }
77
Initialize()78 void SchedulerNode::Initialize() {
79 config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
80 MS_EXCEPTION_IF_NULL(config_);
81 if (!config_->Initialize()) {
82 MS_LOG(INFO) << "The config file is empty.";
83 }
84 InitCommandHandler();
85 CreateTcpServer();
86 is_already_stopped_ = false;
87 if (PSContext::instance()->node_id().empty() && config_->Exists(kNodeId)) {
88 node_info_.node_id_ = config_->Get(kNodeId, "");
89 } else {
90 node_info_.node_id_ = PSContext::instance()->node_id();
91 }
92
93 if (node_info_.node_id_.empty()) {
94 node_info_.node_id_ = CommUtil::GenerateUUID();
95 }
96 node_info_.node_role_ = NodeRole::SCHEDULER;
97 leader_scaler_ = std::make_unique<LeaderScaler>(this);
98 MS_EXCEPTION_IF_NULL(leader_scaler_);
99 instance_manager_ = std::make_unique<InstanceManager>(this);
100 MS_LOG(INFO) << "[Scheduler start]: 2. The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
101 << ", the node id is:" << node_info_.node_id_ << " create a tcp server.";
102 }
103
InitCommandHandler()104 void SchedulerNode::InitCommandHandler() {
105 handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
106 handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
107 handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
108 handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
109 handlers_[NodeCommand::SCALE_OUT_DONE] = &SchedulerNode::ProcessScaleOutDone;
110 handlers_[NodeCommand::SCALE_IN_DONE] = &SchedulerNode::ProcessScaleInDone;
111 handlers_[NodeCommand::SEND_EVENT] = &SchedulerNode::ProcessSendEvent;
112 }
113
CreateTcpServer()114 void SchedulerNode::CreateTcpServer() {
115 node_manager_.InitNode();
116
117 std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
118 uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
119 server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port, config_.get());
120 MS_EXCEPTION_IF_NULL(server_);
121 server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
122 const Protos &, const void *data, size_t size) {
123 if (handlers_.count(meta->cmd()) == 0) {
124 MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
125 }
126 const auto &handler_ptr = handlers_[meta->cmd()];
127 (this->*handler_ptr)(server_, conn, meta, data, size);
128 });
129
130 server_->Init();
131
132 scheduler_thread_ = std::make_unique<std::thread>([this]() {
133 MS_LOG(INFO) << "The scheduler node start a tcp server!";
134 this->server_->Start();
135 });
136 MS_EXCEPTION_IF_NULL(scheduler_thread_);
137 }
138
ProcessRegister(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)139 void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
140 const std::shared_ptr<TcpConnection> &conn,
141 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
142 MS_EXCEPTION_IF_NULL(server);
143 MS_EXCEPTION_IF_NULL(conn);
144 MS_EXCEPTION_IF_NULL(meta);
145 MS_EXCEPTION_IF_NULL(data);
146 RegisterMessage register_message;
147 CHECK_RETURN_TYPE(register_message.ParseFromArray(data, SizeToInt(size)));
148
149 const std::string &node_id = register_message.node_id();
150 node_manager_.UpdateHeartbeat(node_id);
151
152 MS_LOG(INFO) << "The node id:" << node_id << " is registering to scheduler.";
153 client_mutex_.lock();
154 if (node_manager_.IsNodeRegistered(node_id)) {
155 MS_LOG(INFO) << "The node id is registered.";
156 if (connected_nodes_.count(node_id)) {
157 (void)connected_nodes_.erase(node_id);
158 }
159 }
160 client_mutex_.unlock();
161
162 // assign worker node and server node rank id
163 uint32_t rank_id = node_manager_.NextRankId(register_message, meta);
164 if (rank_id == UINT32_MAX) {
165 MS_LOG(WARNING) << "The rank id is wrong!";
166 }
167
168 RegisterRespMessage register_resp_message;
169 register_resp_message.set_node_id(node_id);
170
171 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
172 register_resp_message.ByteSizeLong())) {
173 MS_LOG(WARNING) << "Server response message failed.";
174 }
175
176 if (node_manager_.IsAllNodesRegistered()) {
177 is_ready_ = true;
178 MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
179 << " servers registered to scheduer, so the scheduler send meta data to worker/server.";
180 if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
181 auto nodes = node_manager_.nodes_info();
182 for (const auto &id : scale_in_node_ids_) {
183 MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
184 if (nodes.count(id)) {
185 auto scale_in_client = GetOrCreateClient(nodes[id]);
186 SendMetadata(scale_in_client, nodes[id].rank_id_);
187 node_manager_.UpdateHeartbeat(id);
188 }
189 }
190 }
191 node_manager_.UpdateNodesInfo();
192 auto node_infos = node_manager_.nodes_info();
193 for (const auto &kvs : node_infos) {
194 auto client = GetOrCreateClient(kvs.second);
195 MS_EXCEPTION_IF_NULL(client);
196 SendMetadata(client, kvs.second.rank_id_);
197 node_manager_.UpdateHeartbeat(kvs.first);
198 }
199 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
200 wait_start_cond_.notify_all();
201 }
202 }
203
ProcessFinish(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)204 void SchedulerNode::ProcessFinish(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn,
205 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
206 MS_EXCEPTION_IF_NULL(server);
207 MS_EXCEPTION_IF_NULL(conn);
208 MS_EXCEPTION_IF_NULL(meta);
209 MS_EXCEPTION_IF_NULL(data);
210 auto finish_message = std::make_unique<std::string>(reinterpret_cast<const char *>(data), size);
211 MS_EXCEPTION_IF_NULL(finish_message);
212 std::string node_id = *finish_message;
213 MS_LOG(INFO) << "Process finish message from node id:" << node_id;
214 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
215 MS_LOG(WARNING) << "Server response message failed.";
216 }
217
218 auto iter = std::find_if(scale_in_node_ids_.begin(), scale_in_node_ids_.end(), [node_id](auto item) {
219 if (node_id == item) {
220 MS_LOG(INFO) << "The finish node is a scale in node.";
221 return true;
222 }
223 return false;
224 });
225 if (iter != scale_in_node_ids_.end()) {
226 return;
227 }
228
229 node_manager_.AddFinishNode(node_id);
230 if (node_manager_.IsAllNodesFinished()) {
231 auto node_infos = node_manager_.nodes_info();
232 for (const auto &kvs : node_infos) {
233 auto client = GetOrCreateClient(kvs.second);
234 SendFinish(client);
235 }
236 is_finish_ = true;
237 node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
238 wait_finish_cond_.notify_all();
239 }
240 }
241
ProcessFetchMetadata(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t)242 void SchedulerNode::ProcessFetchMetadata(const std::shared_ptr<TcpServer> &server,
243 const std::shared_ptr<TcpConnection> &conn,
244 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t) {
245 MS_EXCEPTION_IF_NULL(server);
246 MS_EXCEPTION_IF_NULL(conn);
247 MS_EXCEPTION_IF_NULL(meta);
248 MS_EXCEPTION_IF_NULL(data);
249 FetchServersRespMessage fetch_servers_message;
250 std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
251
252 *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
253
254 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
255 fetch_servers_message.ByteSizeLong())) {
256 MS_LOG(WARNING) << "Server response message failed.";
257 }
258 }
259
ProcessScaleOutDone(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)260 void SchedulerNode::ProcessScaleOutDone(const std::shared_ptr<TcpServer> &server,
261 const std::shared_ptr<TcpConnection> &conn,
262 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
263 MS_EXCEPTION_IF_NULL(server);
264 MS_EXCEPTION_IF_NULL(conn);
265 MS_EXCEPTION_IF_NULL(meta);
266 MS_EXCEPTION_IF_NULL(data);
267 ScaleOutDoneMessage scale_out_done_message;
268 scale_out_done_message.ParseFromArray(data, SizeToInt(size));
269 std::string node_id = scale_out_done_message.node_id();
270 MS_LOG(INFO) << "The scheduler process a scale_out_done message from node id:" << node_id;
271 node_manager_.AddScaleOutDoneNode(node_id);
272
273 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
274 MS_LOG(WARNING) << "Server response message failed.";
275 }
276
277 if (node_manager_.IsAllNodesScaleOutDone()) {
278 auto node_infos = node_manager_.nodes_info();
279 for (const auto &kvs : node_infos) {
280 auto client = GetOrCreateClient(kvs.second);
281 SendScaleOutDone(client);
282 }
283 is_ready_ = true;
284 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
285 }
286 }
287
ProcessScaleInDone(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)288 void SchedulerNode::ProcessScaleInDone(const std::shared_ptr<TcpServer> &server,
289 const std::shared_ptr<TcpConnection> &conn,
290 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
291 MS_EXCEPTION_IF_NULL(server);
292 MS_EXCEPTION_IF_NULL(conn);
293 MS_EXCEPTION_IF_NULL(meta);
294 MS_EXCEPTION_IF_NULL(data);
295 ScaleInDoneMessage scale_in_done_message;
296 scale_in_done_message.ParseFromArray(data, SizeToInt(size));
297 std::string node_id = scale_in_done_message.node_id();
298 MS_LOG(INFO) << "The scheduler process a scale_in_done message from node id:" << node_id;
299 node_manager_.AddScaleInDoneNode(node_id);
300
301 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
302 MS_LOG(WARNING) << "Server response message failed.";
303 }
304
305 if (node_manager_.IsAllNodesScaleInDone()) {
306 auto node_infos = node_manager_.nodes_info();
307 for (const auto &kvs : node_infos) {
308 auto client = GetOrCreateClient(kvs.second);
309 SendScaleInDone(client);
310 }
311 is_ready_ = true;
312 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
313 }
314 }
315
ProcessSendEvent(const std::shared_ptr<TcpServer> & server,const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const void * data,size_t size)316 void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
317 const std::shared_ptr<TcpConnection> &conn,
318 const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
319 MS_EXCEPTION_IF_NULL(server);
320 MS_EXCEPTION_IF_NULL(conn);
321 MS_EXCEPTION_IF_NULL(meta);
322 MS_EXCEPTION_IF_NULL(data);
323 EventMessage event_message;
324 event_message.ParseFromArray(data, SizeToInt(size));
325 std::string node_id = event_message.node_id();
326 uint32_t event = event_message.event();
327 MS_LOG(DEBUG) << "The scheduler process a event message from node id:" << node_id;
328
329 if (!server->SendMessage(conn, meta, Protos::PROTOBUF, data, size)) {
330 MS_LOG(WARNING) << "Server response message failed.";
331 }
332
333 auto node_infos = node_manager_.nodes_info();
334 for (const auto &kvs : node_infos) {
335 auto client = GetOrCreateClient(kvs.second);
336 SendEvent(client, event);
337 }
338 }
339
SendMetadata(const std::shared_ptr<TcpClient> & client,uint32_t rank_id)340 void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id) {
341 MS_EXCEPTION_IF_NULL(client);
342 auto message_meta = std::make_shared<MessageMeta>();
343 MS_EXCEPTION_IF_NULL(message_meta);
344 message_meta->set_cmd(NodeCommand::SEND_METADATA);
345
346 SendMetadataMessage send_metadata_message;
347 std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
348 send_metadata_message.set_worker_num(node_manager_.worker_num());
349 send_metadata_message.set_server_num(node_manager_.server_num());
350 send_metadata_message.set_cluster_state(node_manager_.GetClusterState());
351 send_metadata_message.set_rank_id(rank_id);
352
353 *send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
354
355 if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
356 send_metadata_message.ByteSizeLong())) {
357 MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
358 << " the node id:" << node_info_.node_id_ << " send metadata timeout!";
359 }
360
361 MS_LOG(DEBUG) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
362 << " the node id:" << node_info_.node_id_ << "is sending metadata to workers and servers!";
363 }
364
SendFinish(const std::shared_ptr<TcpClient> & client)365 void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
366 MS_EXCEPTION_IF_NULL(client);
367 auto message_meta = std::make_shared<MessageMeta>();
368 MS_EXCEPTION_IF_NULL(message_meta);
369 message_meta->set_cmd(NodeCommand::FINISH);
370
371 // The scheduler does not need to bring any data when sending the finish command
372 std::string resp_data;
373
374 if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
375 MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
376 << " the node id:" << node_info_.node_id_ << " send finish timeout!";
377 }
378
379 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
380 << " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
381 }
382
SendScaleOutDone(const std::shared_ptr<TcpClient> & client)383 void SchedulerNode::SendScaleOutDone(const std::shared_ptr<TcpClient> &client) {
384 MS_EXCEPTION_IF_NULL(client);
385 auto message_meta = std::make_shared<MessageMeta>();
386 MS_EXCEPTION_IF_NULL(message_meta);
387 message_meta->set_cmd(NodeCommand::SCALE_OUT_DONE);
388
389 // The scheduler does not need to bring any data when sending the scale_out_done command
390 std::string resp_data;
391
392 if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
393 MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
394 << " the node id:" << node_info_.node_id_ << " send scale_out_done timeout!";
395 }
396
397 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
398 << " the node id:" << node_info_.node_id_ << "is sending scale_out_done to workers and servers!";
399 }
400
SendScaleInDone(const std::shared_ptr<TcpClient> & client)401 void SchedulerNode::SendScaleInDone(const std::shared_ptr<TcpClient> &client) {
402 MS_EXCEPTION_IF_NULL(client);
403 auto message_meta = std::make_shared<MessageMeta>();
404 MS_EXCEPTION_IF_NULL(message_meta);
405 message_meta->set_cmd(NodeCommand::SCALE_IN_DONE);
406
407 // The scheduler does not need to bring any data when sending the scale_in_done command
408 std::string resp_data;
409
410 if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
411 MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
412 << " the node id:" << node_info_.node_id_ << " send scale_in_done timeout!";
413 }
414
415 MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
416 << " the node id:" << node_info_.node_id_ << "is sending scale_in_done to workers and servers!";
417 }
418
SendEvent(const std::shared_ptr<TcpClient> & client,const uint32_t & event)419 void SchedulerNode::SendEvent(const std::shared_ptr<TcpClient> &client, const uint32_t &event) {
420 MS_EXCEPTION_IF_NULL(client);
421 auto message_meta = std::make_shared<MessageMeta>();
422 MS_EXCEPTION_IF_NULL(message_meta);
423 message_meta->set_cmd(NodeCommand::SEND_EVENT);
424
425 EventRespMessage event_resp_message;
426 event_resp_message.set_event(event);
427
428 if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
429 event_resp_message.ByteSizeLong())) {
430 MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
431 << " the node id:" << node_info_.node_id_ << " send event resp timeout!";
432 return;
433 }
434
435 MS_LOG(DEBUG) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
436 << " the node id:" << node_info_.node_id_ << "is sending event resp to workers and servers!";
437 }
438
StartUpdateClusterStateTimer()439 void SchedulerNode::StartUpdateClusterStateTimer() {
440 MS_LOG(INFO) << "[Scheduler start]: 3. The scheduler start a heartbeat timer!";
441 update_state_thread_ = std::make_unique<std::thread>([&]() {
442 auto start_time = std::chrono::steady_clock::now();
443 while (!is_finish_.load()) {
444 // 1. update cluster timeout
445 if (!is_ready_ && (std::chrono::steady_clock::now() - start_time >
446 std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
447 node_manager_.CheckClusterTimeout();
448 }
449 std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
450 node_manager_.UpdateCluster();
451
452 if (node_manager_.GetClusterState() == ClusterState::CLUSTER_EXIT) {
453 std::this_thread::sleep_for(
454 std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * kHeartbeatTimes));
455 is_finish_ = true;
456 wait_finish_cond_.notify_all();
457 }
458 }
459 });
460 MS_EXCEPTION_IF_NULL(update_state_thread_);
461 }
462
GetOrCreateClient(const NodeInfo & node_info)463 const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInfo &node_info) {
464 std::lock_guard<std::mutex> lock(client_mutex_);
465 if (connected_nodes_.count(node_info.node_id_)) {
466 return connected_nodes_[node_info.node_id_];
467 } else {
468 if (config_ == nullptr) {
469 MS_LOG(EXCEPTION) << "The config is empty.";
470 }
471 std::string ip = node_info.ip_;
472 uint16_t port = node_info.port_;
473 auto client = std::make_shared<TcpClient>(ip, port, config_.get());
474 MS_EXCEPTION_IF_NULL(client);
475 client->SetMessageCallback(
476 [&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
477 switch (meta->cmd()) {
478 case NodeCommand::SEND_DATA:
479 ProcessSendDataResp(meta, protos, data, size);
480 RunMessageCallback(meta->request_id());
481 break;
482 default:
483 MS_LOG(DEBUG) << "The cmd:" << meta->cmd();
484 }
485 NotifyMessageArrival(meta);
486 });
487 client->Init();
488 if (is_client_started_ == false) {
489 is_client_started_ = true;
490 client_thread_ = std::make_unique<std::thread>([&]() {
491 MS_LOG(INFO) << "The node start a tcp client!";
492 client->Start();
493 });
494 MS_EXCEPTION_IF_NULL(client_thread_);
495 }
496
497 connected_nodes_[node_info.node_id_] = client;
498 return connected_nodes_[node_info.node_id_];
499 }
500 }
501
Stop()502 bool SchedulerNode::Stop() {
503 MS_LOG(INFO) << "Stop scheduler node!";
504 if (!is_already_stopped_) {
505 MS_ERROR_IF_NULL_W_RET_VAL(update_state_thread_, false);
506 MS_ERROR_IF_NULL_W_RET_VAL(server_, false);
507 MS_ERROR_IF_NULL_W_RET_VAL(scheduler_thread_, false);
508 is_already_stopped_ = true;
509 update_state_thread_->join();
510 server_->Stop();
511 scheduler_thread_->join();
512 if (!connected_nodes_.empty()) {
513 for (auto &connected_node : connected_nodes_) {
514 auto client = connected_node.second;
515 MS_ERROR_IF_NULL_W_RET_VAL(client, false);
516 client->Stop();
517 }
518 }
519 if (client_thread_ != nullptr && client_thread_->joinable()) {
520 client_thread_->join();
521 }
522 is_ready_ = true;
523 }
524 if (PSContext::instance()->scheduler_manage_port() != 0) {
525 MS_LOG(WARNING) << "Stop the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip()
526 << ", the port:" << PSContext::instance()->scheduler_manage_port();
527 StopRestfulServer();
528 }
529 return true;
530 }
531
Finish(const uint32_t &)532 bool SchedulerNode::Finish(const uint32_t &) {
533 MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
534 std::unique_lock<std::mutex> lock(wait_finish_mutex_);
535 wait_finish_cond_.wait(lock, [this] {
536 if (this->is_finish_.load()) {
537 MS_LOG(INFO) << "[Scheduler finish]: 2. Successfully finish scheduler!";
538 }
539 return this->is_finish_.load();
540 });
541 return true;
542 }
543
ProcessScaleOut(const std::shared_ptr<HttpMessageHandler> & resp)544 void SchedulerNode::ProcessScaleOut(const std::shared_ptr<HttpMessageHandler> &resp) {
545 MS_EXCEPTION_IF_NULL(resp);
546 RequestProcessResult status(RequestProcessResultCode::kSuccess);
547 status = resp->ParsePostMessageToJson();
548 if (status != RequestProcessResultCode::kSuccess) {
549 resp->ErrorResponse(HTTP_BADREQUEST, status);
550 return;
551 }
552
553 int32_t scale_worker_num = 0;
554 status = resp->ParseValueFromKey(kWorkerNum, &scale_worker_num);
555 if (status != RequestProcessResultCode::kSuccess) {
556 resp->ErrorResponse(HTTP_BADREQUEST, status);
557 return;
558 }
559
560 int32_t scale_server_num = 0;
561 status = resp->ParseValueFromKey(kServerNum, &scale_server_num);
562 if (status != RequestProcessResultCode::kSuccess) {
563 resp->ErrorResponse(HTTP_BADREQUEST, status);
564 return;
565 }
566
567 status = CheckIfClusterReady();
568 if (status != RequestProcessResultCode::kSuccess) {
569 resp->ErrorResponse(HTTP_BADREQUEST, status);
570 return;
571 }
572
573 int32_t total_worker_num = scale_worker_num + node_manager_.worker_num();
574 int32_t total_server_num = scale_server_num + node_manager_.server_num();
575
576 MS_LOG(INFO) << "After scale out, the total worker num:" << total_worker_num
577 << ", the total server num:" << total_server_num;
578
579 node_manager_.set_worker_num(total_worker_num);
580 node_manager_.set_server_num(total_server_num);
581 node_manager_.set_total_node_num(total_worker_num + total_server_num);
582
583 node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
584 auto node_infos = node_manager_.nodes_info();
585 node_manager_.ResetMetadata();
586 for (const auto &kvs : node_infos) {
587 auto client = GetOrCreateClient(kvs.second);
588 MS_EXCEPTION_IF_NULL(client);
589 MS_EXCEPTION_IF_NULL(leader_scaler_);
590 leader_scaler_->ScaleOutAsync(client, node_manager_);
591 }
592 MS_LOG(INFO) << "Scheduler send scale out successful.";
593
594 nlohmann::json js;
595 js["message"] = "Cluster begin to scale out.";
596 resp->AddRespString(js.dump());
597 resp->AddRespHeadParam("Content-Type", "application/json");
598
599 resp->SetRespCode(HTTP_OK);
600 resp->SendResponse();
601 }
602
603 /*
604 * The response body format.
605 * {
606 * "node_ids": ["node_id1", "node_id2"]
607 * }
608 */
ProcessScaleIn(const std::shared_ptr<HttpMessageHandler> & resp)609 void SchedulerNode::ProcessScaleIn(const std::shared_ptr<HttpMessageHandler> &resp) {
610 MS_EXCEPTION_IF_NULL(resp);
611 RequestProcessResult status(RequestProcessResultCode::kSuccess);
612 status = resp->ParsePostMessageToJson();
613 if (status != RequestProcessResultCode::kSuccess) {
614 resp->ErrorResponse(HTTP_BADREQUEST, status);
615 }
616
617 status = CheckIfClusterReady();
618 if (status != RequestProcessResultCode::kSuccess) {
619 resp->ErrorResponse(HTTP_BADREQUEST, status);
620 return;
621 }
622
623 scale_in_node_ids_.clear();
624 status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids_);
625 if (status != RequestProcessResultCode::kSuccess) {
626 resp->ErrorResponse(HTTP_BADREQUEST, status);
627 return;
628 }
629
630 status = CheckIfNodeIdLegal(scale_in_node_ids_);
631 if (status != RequestProcessResultCode::kSuccess) {
632 resp->ErrorResponse(HTTP_BADREQUEST, status);
633 return;
634 }
635
636 MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids_;
637
638 std::unordered_map<std::string, bool> scale_in_nodes;
639
640 int32_t scale_worker_num = 0;
641 int32_t scale_server_num = 0;
642 auto node_infos = node_manager_.nodes_info();
643 node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
644 node_manager_.ResetMetadata(scale_in_node_ids_);
645 for (auto const &val : scale_in_node_ids_) {
646 if (node_infos.count(val)) {
647 scale_in_nodes[val] = true;
648 NodeInfo info = node_infos[val];
649 if (info.node_role_ == NodeRole::WORKER) {
650 scale_worker_num++;
651 } else if (info.node_role_ == NodeRole::SERVER) {
652 scale_server_num++;
653 }
654 }
655 }
656
657 MS_LOG(INFO) << "The scale worker num:" << scale_worker_num << ", the scale server num:" << scale_server_num;
658
659 int32_t total_worker_num = node_manager_.worker_num() - scale_worker_num;
660 int32_t total_server_num = node_manager_.server_num() - scale_server_num;
661
662 node_manager_.set_worker_num(total_worker_num);
663 node_manager_.set_server_num(total_server_num);
664 node_manager_.set_total_node_num(total_worker_num + total_server_num);
665 for (const auto &kvs : node_infos) {
666 auto client = GetOrCreateClient(kvs.second);
667 bool is_node_scale_in = false;
668 if (scale_in_nodes.count(kvs.first)) {
669 is_node_scale_in = true;
670 }
671 MS_EXCEPTION_IF_NULL(leader_scaler_);
672 leader_scaler_->ScaleInAsync(client, node_manager_, is_node_scale_in);
673 }
674
675 nlohmann::json js;
676 js["message"] = "Cluster begin to scale in.";
677 resp->AddRespString(js.dump());
678 resp->AddRespHeadParam("Content-Type", "application/json");
679
680 resp->SetRespCode(HTTP_OK);
681 resp->SendResponse();
682 }
683
684 /*
685 * The response body format.
686 * {
687 * "message": "Get nodes info successful.",
688 * "node_ids": [
689 * {
690 * "node_id": "node_id1",
691 * "rank_id": "0",
692 * "role": "SERVER"
693 * },
694 * {
695 * "node_id": "node_id2",
696 * "rank_id": "1",
697 * "role": "WORKER"
698 * }
699 * ]
700 * }
701 */
ProcessGetNodesInfo(const std::shared_ptr<HttpMessageHandler> & resp)702 void SchedulerNode::ProcessGetNodesInfo(const std::shared_ptr<HttpMessageHandler> &resp) {
703 MS_EXCEPTION_IF_NULL(resp);
704 nlohmann::json js;
705 js["message"] = "Get nodes info successful.";
706 auto node_infos = node_manager_.nodes_info();
707 for (const auto &kvs : node_infos) {
708 std::unordered_map<std::string, std::string> res;
709 res["node_id"] = kvs.second.node_id_;
710 res["rank_id"] = std::to_string(kvs.second.rank_id_);
711 res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
712 js["node_ids"].push_back(res);
713 }
714
715 resp->AddRespString(js.dump());
716 resp->AddRespHeadParam("Content-Type", "application/json");
717
718 resp->SetRespCode(HTTP_OK);
719 resp->SendResponse();
720 }
721
722 /*
723 * The response body format.
724 * {
725 * "message": "Get cluster state successful.",
726 * "cluster_state": "CLUSTER_READY"
727 * }
728 */
ProcessGetClusterState(const std::shared_ptr<HttpMessageHandler> & resp)729 void SchedulerNode::ProcessGetClusterState(const std::shared_ptr<HttpMessageHandler> &resp) {
730 MS_EXCEPTION_IF_NULL(resp);
731 nlohmann::json js;
732 js["message"] = "Get cluster state successful.";
733 auto cluster_state = node_manager_.GetClusterState();
734 js["cluster_state"] = CommUtil::ClusterStateToString(cluster_state);
735
736 resp->AddRespString(js.dump());
737 resp->AddRespHeadParam("Content-Type", "application/json");
738
739 resp->SetRespCode(HTTP_OK);
740 resp->SendResponse();
741 }
742
ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> & resp)743 void SchedulerNode::ProcessNewInstance(const std::shared_ptr<HttpMessageHandler> &resp) {
744 MS_EXCEPTION_IF_NULL(resp);
745
746 RequestProcessResult status(RequestProcessResultCode::kSuccess);
747
748 status = CheckIfClusterReady();
749 if (status != RequestProcessResultCode::kSuccess) {
750 resp->ErrorResponse(HTTP_BADREQUEST, status);
751 return;
752 }
753
754 status = resp->ParsePostMessageToJson();
755 if (status != RequestProcessResultCode::kSuccess) {
756 resp->ErrorResponse(HTTP_BADREQUEST, status);
757 return;
758 }
759
760 node_manager_.UpdateClusterState(ClusterState::CLUSTER_NEW_INSTANCE);
761
762 std::string body = resp->request_message().dump();
763
764 uint64_t request_id = AddMessageTrack(node_manager_.server_num());
765
766 std::unordered_map<uint32_t, VectorPtr> outputs;
767
768 set_message_callback(request_id, [&]() {
769 receive_messages_mutex_.lock();
770 outputs = receive_messages_[request_id];
771 receive_messages_.erase(request_id);
772 receive_messages_mutex_.unlock();
773 });
774
775 auto node_infos = node_manager_.nodes_info();
776 for (const auto &kvs : node_infos) {
777 if (kvs.second.node_role_ == NodeRole::SERVER) {
778 auto client = GetOrCreateClient(kvs.second);
779 MS_EXCEPTION_IF_NULL(client);
780 MS_EXCEPTION_IF_NULL(instance_manager_);
781 instance_manager_->NewInstanceAsync(client, node_manager_, body, request_id, node_info_);
782 }
783 }
784 bool res = Wait(request_id);
785 if (!res) {
786 ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The new instance is timeout.");
787 resp->ErrorResponse(HTTP_BADREQUEST, status);
788 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
789 return;
790 }
791
792 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
793 nlohmann::json js;
794 js["message"] = "Start update flPlan successful.";
795 for (const auto &output : outputs) {
796 std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
797 js["result"][output.first] = data;
798 }
799
800 resp->AddRespString(js.dump());
801 resp->AddRespHeadParam("Content-Type", "application/json");
802
803 resp->SetRespCode(HTTP_OK);
804 resp->SendResponse();
805 }
806
ProcessQueryInstance(const std::shared_ptr<HttpMessageHandler> & resp)807 void SchedulerNode::ProcessQueryInstance(const std::shared_ptr<HttpMessageHandler> &resp) {
808 MS_EXCEPTION_IF_NULL(resp);
809
810 RequestProcessResult status(RequestProcessResultCode::kSuccess);
811
812 status = CheckIfClusterReady();
813 if (status != RequestProcessResultCode::kSuccess) {
814 resp->ErrorResponse(HTTP_BADREQUEST, status);
815 return;
816 }
817
818 uint64_t request_id = AddMessageTrack(node_manager_.server_num());
819
820 std::unordered_map<uint32_t, VectorPtr> outputs;
821
822 set_message_callback(request_id, [&]() {
823 receive_messages_mutex_.lock();
824 outputs = receive_messages_[request_id];
825 receive_messages_.erase(request_id);
826 receive_messages_mutex_.unlock();
827 });
828
829 auto node_infos = node_manager_.nodes_info();
830 for (const auto &kvs : node_infos) {
831 if (kvs.second.node_role_ == NodeRole::SERVER) {
832 auto client = GetOrCreateClient(kvs.second);
833 MS_EXCEPTION_IF_NULL(client);
834 MS_EXCEPTION_IF_NULL(instance_manager_);
835 instance_manager_->QueryInstanceAsync(client, node_manager_, request_id, node_info_);
836 }
837 }
838 bool res = Wait(request_id);
839 if (!res) {
840 ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The query instance is timeout.");
841 resp->ErrorResponse(HTTP_BADREQUEST, status);
842 return;
843 }
844
845 nlohmann::json js;
846 js["message"] = "Start update flPlan successful.";
847 for (const auto &output : outputs) {
848 std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
849 js["result"][output.first] = data;
850 }
851
852 resp->AddRespString(js.dump());
853 resp->AddRespHeadParam("Content-Type", "application/json");
854
855 resp->SetRespCode(HTTP_OK);
856 resp->SendResponse();
857 }
858
ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> & resp)859 void SchedulerNode::ProcessEnableFLS(const std::shared_ptr<HttpMessageHandler> &resp) {
860 MS_EXCEPTION_IF_NULL(resp);
861
862 RequestProcessResult status(RequestProcessResultCode::kSuccess);
863
864 status = CheckIfClusterReady();
865 if (status != RequestProcessResultCode::kSuccess) {
866 resp->ErrorResponse(HTTP_BADREQUEST, status);
867 return;
868 }
869
870 node_manager_.UpdateClusterState(ClusterState::CLUSTER_ENABLE_FLS);
871
872 uint64_t request_id = AddMessageTrack(node_manager_.server_num());
873
874 std::unordered_map<uint32_t, VectorPtr> outputs;
875
876 set_message_callback(request_id, [&]() {
877 receive_messages_mutex_.lock();
878 outputs = receive_messages_[request_id];
879 receive_messages_.erase(request_id);
880 receive_messages_mutex_.unlock();
881 });
882
883 auto node_infos = node_manager_.nodes_info();
884 for (const auto &kvs : node_infos) {
885 if (kvs.second.node_role_ == NodeRole::SERVER) {
886 auto client = GetOrCreateClient(kvs.second);
887 MS_EXCEPTION_IF_NULL(client);
888 MS_EXCEPTION_IF_NULL(instance_manager_);
889 instance_manager_->EnableFLSAsync(client, node_manager_, request_id, node_info_);
890 }
891 }
892 bool res = Wait(request_id);
893 if (!res) {
894 ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The enable FLS is timeout.");
895 resp->ErrorResponse(HTTP_BADREQUEST, status);
896 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
897 return;
898 }
899
900 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
901 nlohmann::json js;
902 js["message"] = "start enabling FL-Server successful.";
903 for (const auto &output : outputs) {
904 std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
905 js["result"][output.first] = data;
906 }
907
908 resp->AddRespString(js.dump());
909 resp->AddRespHeadParam("Content-Type", "application/json");
910
911 resp->SetRespCode(HTTP_OK);
912 resp->SendResponse();
913 }
914
ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> & resp)915 void SchedulerNode::ProcessDisableFLS(const std::shared_ptr<HttpMessageHandler> &resp) {
916 MS_EXCEPTION_IF_NULL(resp);
917
918 RequestProcessResult status(RequestProcessResultCode::kSuccess);
919
920 status = CheckIfClusterReady();
921 if (status != RequestProcessResultCode::kSuccess) {
922 resp->ErrorResponse(HTTP_BADREQUEST, status);
923 return;
924 }
925
926 node_manager_.UpdateClusterState(ClusterState::CLUSTER_DISABLE_FLS);
927
928 uint64_t request_id = AddMessageTrack(node_manager_.server_num());
929
930 std::unordered_map<uint32_t, VectorPtr> outputs;
931
932 set_message_callback(request_id, [&]() {
933 receive_messages_mutex_.lock();
934 outputs = receive_messages_[request_id];
935 receive_messages_.erase(request_id);
936 receive_messages_mutex_.unlock();
937 });
938
939 auto node_infos = node_manager_.nodes_info();
940 for (const auto &kvs : node_infos) {
941 if (kvs.second.node_role_ == NodeRole::SERVER) {
942 auto client = GetOrCreateClient(kvs.second);
943 MS_EXCEPTION_IF_NULL(client);
944 MS_EXCEPTION_IF_NULL(instance_manager_);
945 instance_manager_->DisableFLSAsync(client, node_manager_, request_id, node_info_);
946 }
947 }
948 bool res = Wait(request_id);
949 if (!res) {
950 ERROR_STATUS(status, RequestProcessResultCode::kInvalidInputs, "The disable FLS is timeout.");
951 resp->ErrorResponse(HTTP_BADREQUEST, status);
952 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
953 return;
954 }
955
956 node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
957 nlohmann::json js;
958 js["message"] = "start disabling FL-Server successful.";
959 for (const auto &output : outputs) {
960 std::string data = std::string(reinterpret_cast<char *>(output.second->data()), output.second->size());
961 js["result"][output.first] = data;
962 }
963
964 resp->AddRespString(js.dump());
965 resp->AddRespHeadParam("Content-Type", "application/json");
966
967 resp->SetRespCode(HTTP_OK);
968 resp->SendResponse();
969 }
970
CheckIfClusterReady()971 RequestProcessResult SchedulerNode::CheckIfClusterReady() {
972 RequestProcessResult result(RequestProcessResultCode::kSuccess);
973 if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) {
974 std::string message = "The cluster is not ready.";
975 ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message);
976 return result;
977 }
978 return result;
979 }
980
CheckIfNodeIdLegal(const std::vector<std::string> & node_ids)981 RequestProcessResult SchedulerNode::CheckIfNodeIdLegal(const std::vector<std::string> &node_ids) {
982 RequestProcessResult result(RequestProcessResultCode::kSuccess);
983 if (node_ids.size() == 0) {
984 std::string message = "The node ids should not be empty.";
985 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
986 return result;
987 }
988
989 auto node_infos = node_manager_.nodes_info();
990
991 for (auto val : node_ids) {
992 if (!node_infos.count(val)) {
993 std::string message = "The node id:" + val + " is illegal.";
994 MS_LOG(ERROR) << message;
995 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, message);
996 return result;
997 }
998
999 if (node_infos[val].node_role_ == NodeRole::SERVER && node_infos[val].rank_id_ == 0) {
1000 std::string error_message = "The node id:" + val + " is rank 0 of server, should not be scale in.";
1001 MS_LOG(ERROR) << error_message;
1002 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
1003 return result;
1004 }
1005
1006 if (node_infos[val].node_role_ == NodeRole::WORKER) {
1007 std::string error_message = "The node id:" + val + " is the role of worker, should not be scale in.";
1008 MS_LOG(ERROR) << error_message;
1009 ERROR_STATUS(result, RequestProcessResultCode::kInvalidInputs, error_message);
1010 return result;
1011 }
1012 }
1013
1014 return result;
1015 }
1016
StartRestfulServer(const std::string & address,std::uint16_t port,size_t thread_num)1017 void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num) {
1018 MS_LOG(INFO) << "Scheduler start https server.";
1019 http_server_ = std::make_shared<HttpServer>(address, port, thread_num);
1020 MS_EXCEPTION_IF_NULL(http_server_);
1021
1022 OnRequestReceive scale_out = std::bind(&SchedulerNode::ProcessScaleOut, this, std::placeholders::_1);
1023 callbacks_["/scaleout"] = scale_out;
1024 http_server_->RegisterRoute("/scaleout", &callbacks_["/scaleout"]);
1025
1026 OnRequestReceive scale_in = std::bind(&SchedulerNode::ProcessScaleIn, this, std::placeholders::_1);
1027 callbacks_["/scalein"] = scale_in;
1028 http_server_->RegisterRoute("/scalein", &callbacks_["/scalein"]);
1029
1030 OnRequestReceive nodes = std::bind(&SchedulerNode::ProcessGetNodesInfo, this, std::placeholders::_1);
1031 callbacks_["/nodes"] = nodes;
1032 http_server_->RegisterRoute("/nodes", &callbacks_["/nodes"]);
1033
1034 OnRequestReceive cluster_state = std::bind(&SchedulerNode::ProcessGetClusterState, this, std::placeholders::_1);
1035 callbacks_["/state"] = cluster_state;
1036 http_server_->RegisterRoute("/state", &callbacks_["/state"]);
1037
1038 OnRequestReceive new_instance = std::bind(&SchedulerNode::ProcessNewInstance, this, std::placeholders::_1);
1039 callbacks_["/newInstance"] = new_instance;
1040 http_server_->RegisterRoute("/newInstance", &callbacks_["/newInstance"]);
1041
1042 OnRequestReceive query_instance = std::bind(&SchedulerNode::ProcessQueryInstance, this, std::placeholders::_1);
1043 callbacks_["/queryInstance"] = query_instance;
1044 http_server_->RegisterRoute("/queryInstance", &callbacks_["/queryInstance"]);
1045
1046 OnRequestReceive enable_fls = std::bind(&SchedulerNode::ProcessEnableFLS, this, std::placeholders::_1);
1047 callbacks_["/enableFLS"] = enable_fls;
1048 http_server_->RegisterRoute("/enableFLS", &callbacks_["/enableFLS"]);
1049
1050 OnRequestReceive disable_fls = std::bind(&SchedulerNode::ProcessDisableFLS, this, std::placeholders::_1);
1051 callbacks_["/disableFLS"] = disable_fls;
1052 http_server_->RegisterRoute("/disableFLS", &callbacks_["/disableFLS"]);
1053
1054 if (!http_server_->InitServer()) {
1055 MS_LOG(EXCEPTION) << "The scheduler init http server failed.";
1056 }
1057
1058 if (!http_server_->Start(false)) {
1059 MS_LOG(EXCEPTION) << "The scheduler start http server failed.";
1060 }
1061 restful_thread_ = std::make_unique<std::thread>([&]() { http_server_->Wait(); });
1062 MS_EXCEPTION_IF_NULL(restful_thread_);
1063 }
1064
StopRestfulServer()1065 void SchedulerNode::StopRestfulServer() {
1066 MS_LOG(INFO) << "Scheduler stop https server.";
1067 MS_ERROR_IF_NULL_WO_RET_VAL(http_server_);
1068 MS_ERROR_IF_NULL_WO_RET_VAL(restful_thread_);
1069 if (!http_server_->Stop()) {
1070 MS_LOG(WARNING) << "Scheduler stop https server failed.";
1071 }
1072 if (restful_thread_ != nullptr && restful_thread_->joinable()) {
1073 restful_thread_->join();
1074 }
1075 }
1076 } // namespace core
1077 } // namespace ps
1078 } // namespace mindspore
1079