1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "distributed/rpc/tcp/tcp_comm.h"
18
19 #include <mutex>
20 #include <utility>
21 #include <memory>
22
23 #include "actor/aid.h"
24 #include "include/backend/distributed/rpc/tcp/constants.h"
25 #include "distributed/rpc/tcp/tcp_socket_operation.h"
26
27 namespace mindspore {
28 namespace distributed {
29 namespace rpc {
DoDisconnect(int fd,Connection * conn,uint32_t error,int soError)30 void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
31 if (conn == nullptr) {
32 return;
33 }
34 if (LOG_CHECK_EVERY_N()) {
35 MS_LOG(INFO) << "Failed to call connect, fd: " << fd << "from : " << conn->source << " to: " << conn->destination
36 << ", events: " << error << ", errno: " << soError << " " << strerror(soError);
37 }
38
39 conn->state = ConnectionState::kDisconnecting;
40 conn->error_code = soError;
41 conn->event_callback(conn);
42 return;
43 }
44
ConnectedEventHandler(int fd,uint32_t events,void * context)45 void ConnectedEventHandler(int fd, uint32_t events, void *context) {
46 uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
47 int soError = 0;
48 Connection *conn = reinterpret_cast<Connection *>(context);
49 if (conn == nullptr || conn->socket_operation == nullptr) {
50 return;
51 }
52 if (conn->state == ConnectionState::kDisconnecting) {
53 DoDisconnect(fd, conn, error, soError);
54 return;
55 }
56
57 if (!conn->ReconnectSourceSocket(fd, events, &soError, error)) {
58 DoDisconnect(fd, conn, error, soError);
59 return;
60 }
61 if (conn->write_callback) {
62 conn->write_callback(conn);
63 }
64 MS_LOG(WARNING) << "Connection from " << conn->source << " to " << conn->destination << "is successfully created "
65 << strerror(errno);
66 conn->socket_operation->ConnEstablishedEventHandler(fd, events, context);
67 return;
68 }
69
OnAccept(int server,uint32_t events,void * arg)70 void OnAccept(int server, uint32_t events, void *arg) {
71 if (events & (EPOLLHUP | EPOLLERR)) {
72 MS_LOG(ERROR) << "Invalid error event, server fd: " << server << ", events: " << events;
73 return;
74 }
75 TCPComm *tcpmgr = reinterpret_cast<TCPComm *>(arg);
76 if (tcpmgr == nullptr || tcpmgr->conn_pool_ == nullptr) {
77 return;
78 }
79 if (tcpmgr->recv_event_loop_ == nullptr) {
80 MS_LOG(ERROR) << "EventLoop is null, server fd: " << server << ", events: " << events;
81 return;
82 }
83
84 // accept connection
85 auto acceptFd = SocketOperation::Accept(server);
86 if (acceptFd < 0) {
87 MS_LOG(ERROR) << "Failed to call accept, server fd: " << server << ", events: " << events;
88 return;
89 }
90
91 // This new connection will be added to connection pool.
92 Connection *conn = new (std::nothrow) Connection();
93 if (conn == nullptr) {
94 MS_LOG(ERROR) << "Failed to create new connection, server fd:" << server << ", events: " << events
95 << ", accept fd: " << acceptFd;
96 if (close(acceptFd) != 0) {
97 MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
98 }
99 return;
100 }
101 conn->enable_ssl = tcpmgr->enable_ssl_;
102
103 // init metrics
104 conn->send_metrics = new (std::nothrow) SendMetrics();
105 if (conn->send_metrics == nullptr) {
106 MS_LOG(ERROR) << "Failed to create connection metrics, server fd: " << server << ", events: " << events
107 << ", accept fd: " << acceptFd;
108 if (close(acceptFd) != 0) {
109 MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
110 }
111 delete conn;
112 return;
113 }
114
115 conn->socket_fd = acceptFd;
116 conn->source = tcpmgr->url_;
117 conn->destination = SocketOperation::GetPeer(acceptFd);
118 conn->peer = conn->destination;
119
120 conn->is_remote = true;
121 conn->recv_event_loop = tcpmgr->recv_event_loop_;
122 conn->send_event_loop = tcpmgr->send_event_loop_;
123
124 conn->conn_mutex = tcpmgr->conn_mutex_;
125 conn->message_handler = tcpmgr->message_handler_;
126
127 conn->event_callback = std::bind(&TCPComm::EventCallBack, tcpmgr, std::placeholders::_1);
128 conn->write_callback = std::bind(&TCPComm::WriteCallBack, tcpmgr, std::placeholders::_1);
129 conn->read_callback = std::bind(&TCPComm::ReadCallBack, tcpmgr, std::placeholders::_1);
130
131 conn->SetAllocateCallback(tcpmgr->allocate_cb());
132
133 int retval = conn->Initialize();
134 if (retval != RPC_OK) {
135 MS_LOG(ERROR) << "Failed to add accept fd event, server fd: " << server << ", events: " << events
136 << ", accept fd: " << acceptFd;
137 if (close(acceptFd) != 0) {
138 MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
139 }
140 acceptFd = -1;
141 delete conn->send_metrics;
142 delete conn;
143 conn = nullptr;
144 return;
145 }
146 tcpmgr->conn_pool_->AddConnection(conn);
147 }
148
SetMessageHandler(const MessageHandler & handler)149 void TCPComm::SetMessageHandler(const MessageHandler &handler) { message_handler_ = handler; }
150
Initialize()151 bool TCPComm::Initialize() {
152 conn_pool_ = std::make_shared<ConnectionPool>();
153 MS_EXCEPTION_IF_NULL(conn_pool_);
154
155 conn_mutex_ = std::make_shared<std::mutex>();
156 MS_EXCEPTION_IF_NULL(conn_mutex_);
157
158 recv_event_loop_ = new (std::nothrow) EventLoop();
159 if (recv_event_loop_ == nullptr) {
160 MS_LOG(ERROR) << "Failed to create recv evLoop.";
161 return false;
162 }
163
164 bool ok = recv_event_loop_->Initialize(TCP_RECV_EVLOOP_THREADNAME);
165 if (!ok) {
166 MS_LOG(ERROR) << "Failed to init recv evLoop";
167 delete recv_event_loop_;
168 recv_event_loop_ = nullptr;
169 return false;
170 }
171
172 send_event_loop_ = new (std::nothrow) EventLoop();
173 if (send_event_loop_ == nullptr) {
174 MS_LOG(ERROR) << "Failed to create send evLoop.";
175 delete recv_event_loop_;
176 recv_event_loop_ = nullptr;
177 return false;
178 }
179 ok = send_event_loop_->Initialize(TCP_SEND_EVLOOP_THREADNAME);
180 if (!ok) {
181 MS_LOG(ERROR) << "Failed to init send evLoop";
182 delete recv_event_loop_;
183 recv_event_loop_ = nullptr;
184 delete send_event_loop_;
185 send_event_loop_ = nullptr;
186 return false;
187 }
188
189 return true;
190 }
191
StartServerSocket(const std::string & url,const MemAllocateCallback & allocate_cb)192 int TCPComm::StartServerSocket(const std::string &url, const MemAllocateCallback &allocate_cb) {
193 server_fd_ = SocketOperation::Listen(url);
194 if (server_fd_ < 0) {
195 MS_LOG(WARNING) << "Failed to call socket listen, url: " << url.c_str();
196 return server_fd_;
197 }
198 url_ = url;
199 allocate_cb_ = allocate_cb;
200 size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
201 if (index != std::string::npos) {
202 index = index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1;
203 if (index < url.length()) {
204 url_ = url.substr(index);
205 }
206 }
207
208 // Register read event callback for server socket
209 int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept,
210 reinterpret_cast<void *>(this));
211 if (retval != RPC_OK) {
212 MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str();
213 return -1;
214 }
215 MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str();
216 return 0;
217 }
218
StartServerSocket(const MemAllocateCallback & allocate_cb)219 int TCPComm::StartServerSocket(const MemAllocateCallback &allocate_cb) {
220 auto ip = SocketOperation::GetLocalIP();
221 // The port 0 means that the port will be allocated randomly by the os system.
222 auto url = ip + ":0";
223 return StartServerSocket(url, allocate_cb);
224 }
225
GetServerFd() const226 int TCPComm::GetServerFd() const { return server_fd_; }
227
ReadCallBack(void * connection)228 void TCPComm::ReadCallBack(void *connection) {
229 const int max_recv_count = 3;
230 Connection *conn = reinterpret_cast<Connection *>(connection);
231 if (conn == nullptr) {
232 return;
233 }
234 int count = 0;
235 int retval = 0;
236 do {
237 retval = ReceiveMessage(conn);
238 ++count;
239 } while (retval > 0 && count < max_recv_count);
240
241 return;
242 }
243
EventCallBack(void * connection)244 void TCPComm::EventCallBack(void *connection) {
245 Connection *conn = reinterpret_cast<Connection *>(connection);
246 if (conn == nullptr) {
247 return;
248 }
249 if (conn->state == ConnectionState::kConnected) {
250 conn->conn_mutex->lock();
251 (void)conn->Flush();
252 conn->conn_mutex->unlock();
253 } else if (conn->state == ConnectionState::kDisconnecting) {
254 std::lock_guard<std::mutex> lock(*conn_mutex_);
255 auto current_conn = conn_pool_->FindConnection(conn->destination);
256 if (current_conn != nullptr) {
257 if (current_conn->source != conn->source) {
258 MS_LOG(WARNING) << "Current connection created to " << conn->destination << " is from " << current_conn->source
259 << ", not " << conn->source << ". No need to delete.";
260 return;
261 }
262 }
263 MS_LOG(INFO) << "The connection state is kDisconnecting. Start disconnecting from " << conn->source << " to "
264 << conn->destination;
265 conn_pool_->DeleteConnection(conn->destination);
266 }
267 }
268
WriteCallBack(void * connection)269 void TCPComm::WriteCallBack(void *connection) {
270 Connection *conn = reinterpret_cast<Connection *>(connection);
271 if (conn == nullptr) {
272 return;
273 }
274 if (conn->state == ConnectionState::kConnected) {
275 conn->conn_mutex->lock();
276 (void)conn->Flush();
277 conn->conn_mutex->unlock();
278 }
279 }
280
281 /* static method */
ReceiveMessage(Connection * conn)282 int TCPComm::ReceiveMessage(Connection *conn) {
283 std::lock_guard<std::mutex> lock(*conn->conn_mutex);
284 conn->CheckMessageType();
285 switch (conn->recv_message_type) {
286 case ParseType::kTcpMsg:
287 return conn->ReceiveMessage();
288 default:
289 return 0;
290 }
291 }
292
293 /* static method */
SetConnectedHandler(Connection * conn)294 int TCPComm::SetConnectedHandler(Connection *conn) {
295 /* add to epoll */
296 return conn->recv_event_loop->SetEventHandler(conn->socket_fd,
297 static_cast<uint32_t>(EPOLLOUT | EPOLLHUP | EPOLLRDHUP | EPOLLERR),
298 ConnectedEventHandler, reinterpret_cast<void *>(conn));
299 }
300
301 /* static method */
DoConnect(Connection * conn,const struct sockaddr * sa,socklen_t saLen)302 int TCPComm::DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
303 if (conn == nullptr || conn->recv_event_loop == nullptr || sa == nullptr) {
304 return RPC_ERROR;
305 }
306 int retval = 0;
307 uint16_t localPort = 0;
308
309 retval = SocketOperation::Connect(conn->socket_fd, sa, saLen, &localPort);
310 if (retval != RPC_OK) {
311 return RPC_ERROR;
312 }
313
314 // Init connection metrics.
315 if (conn->send_metrics == nullptr) {
316 conn->send_metrics = new (std::nothrow) SendMetrics();
317 if (conn->send_metrics == nullptr) {
318 return RPC_ERROR;
319 }
320 }
321
322 // Add the socket of this connection to epoll.
323 retval = SetConnectedHandler(conn);
324 if (retval != RPC_OK) {
325 if (conn->send_metrics != nullptr) {
326 delete conn->send_metrics;
327 conn->send_metrics = nullptr;
328 }
329 return RPC_ERROR;
330 }
331 return RPC_OK;
332 }
333
334 /* static method */
DropMessage(MessageBase * msg)335 void TCPComm::DropMessage(MessageBase *msg) {
336 auto *ptr = msg;
337 delete ptr;
338 ptr = nullptr;
339 }
340
Send(MessageBase * msg,size_t * const send_bytes,bool sync)341 bool TCPComm::Send(MessageBase *msg, size_t *const send_bytes, bool sync) {
342 if (msg == nullptr) {
343 return false;
344 }
345 auto task = [msg, send_bytes, this] {
346 std::lock_guard<std::mutex> lock(*conn_mutex_);
347 // Search connection by the target address
348 std::string destination = msg->to.Url();
349 Connection *conn = conn_pool_->FindConnection(destination);
350 if (conn == nullptr) {
351 MS_LOG(WARNING) << "Can not found remote link and send fail name: " << msg->name.c_str()
352 << ", from: " << msg->from.Url().c_str() << ", to: " << destination;
353 DropMessage(msg);
354 return false;
355 }
356
357 if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
358 MS_LOG(WARNING) << "The message queue is full(max len:" << SENDMSG_QUEUELEN
359 << ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
360 << ", to: " << conn->destination.c_str();
361 if (!conn->FreeMessageMemory(msg)) {
362 MS_LOG(ERROR) << "Failed to free memory of the message.";
363 }
364 DropMessage(msg);
365 return false;
366 }
367
368 if (conn->state != ConnectionState::kConnected) {
369 MS_LOG(WARNING) << "Invalid connection state " << conn->state
370 << " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
371 << ", to: " << conn->destination.c_str();
372 if (!conn->FreeMessageMemory(msg)) {
373 MS_LOG(ERROR) << "Failed to free memory of the message.";
374 }
375 DropMessage(msg);
376 return false;
377 }
378
379 if (conn->total_send_len == 0) {
380 conn->FillSendMessage(msg, url_, false);
381 } else {
382 (void)conn->send_message_queue.emplace(msg);
383 }
384 auto bytes = conn->Flush();
385 if (send_bytes != nullptr) {
386 *send_bytes = bytes;
387 }
388 return true;
389 };
390 if (sync) {
391 return task();
392 } else {
393 send_event_loop_->AddTask(task);
394 return true;
395 }
396 }
397
Flush(const std::string & dst_url)398 bool TCPComm::Flush(const std::string &dst_url) {
399 Connection *conn = conn_pool_->FindConnection(dst_url);
400 if (conn == nullptr) {
401 MS_LOG(ERROR) << "Can not find the connection to url: " << dst_url;
402 return false;
403 } else {
404 std::lock_guard<std::mutex> lock(*(conn->conn_mutex));
405 return (conn->Flush() >= 0);
406 }
407 }
408
Connect(const std::string & dst_url,const MemFreeCallback & free_cb)409 bool TCPComm::Connect(const std::string &dst_url, const MemFreeCallback &free_cb) {
410 MS_EXCEPTION_IF_NULL(conn_mutex_);
411 MS_EXCEPTION_IF_NULL(conn_pool_);
412 if (!free_cb) {
413 MS_LOG(EXCEPTION) << "The message callback is empty.";
414 }
415
416 std::lock_guard<std::mutex> lock(*conn_mutex_);
417
418 // Search connection by the target address
419 Connection *conn = conn_pool_->FindConnection(dst_url);
420
421 if (conn == nullptr) {
422 MS_LOG(INFO) << "Can not found link destination: " << dst_url;
423 conn = new (std::nothrow) Connection();
424 if (conn == nullptr) {
425 MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
426 return false;
427 }
428 conn->enable_ssl = enable_ssl_;
429 conn->recv_event_loop = this->recv_event_loop_;
430 conn->send_event_loop = this->send_event_loop_;
431 conn->conn_mutex = conn_mutex_;
432 conn->message_handler = message_handler_;
433 conn->InitSocketOperation();
434
435 // Create the client socket.
436 SocketAddress addr;
437 if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
438 MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
439 return false;
440 }
441 int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
442 if (sock_fd < 0) {
443 MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
444 return false;
445 }
446
447 conn->socket_fd = sock_fd;
448 conn->event_callback = std::bind(&TCPComm::EventCallBack, this, std::placeholders::_1);
449 conn->write_callback = std::bind(&TCPComm::WriteCallBack, this, std::placeholders::_1);
450 conn->read_callback = std::bind(&TCPComm::ReadCallBack, this, std::placeholders::_1);
451
452 int ret = TCPComm::DoConnect(conn, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr));
453 if (ret < 0) {
454 MS_LOG(ERROR) << "Failed to do connect and link fail destination: " << dst_url;
455 if (conn->socket_operation != nullptr) {
456 delete conn->socket_operation;
457 conn->socket_operation = nullptr;
458 }
459 delete conn;
460 conn = nullptr;
461 return false;
462 }
463 conn->source = SocketOperation::GetIP(sock_fd) + ":" + std::to_string(SocketOperation::GetPort(sock_fd));
464 conn->destination = dst_url;
465 dst_url_to_src_ip_[dst_url] = SocketOperation::GetIP(sock_fd);
466 MS_LOG(INFO) << "Connection " << sock_fd << " source: " << conn->source << ", destination: " << conn->destination;
467
468 // Check the state of this new created connection.
469 uint32_t interval = 1;
470 size_t retry = 3;
471 // Record total retry number to avoid duplicated log.
472 static size_t total_retry_count = 0;
473 while (conn->state < ConnectionState::kConnected && retry-- > 0) {
474 MS_LOG(WARNING) << "Waiting for the state of the connection to " << dst_url
475 << " to be connected...Retry number: " << ++total_retry_count;
476 (void)sleep(interval);
477 }
478 if (conn->state != ConnectionState::kConnected) {
479 return false;
480 }
481 conn_pool_->AddConnection(conn);
482 conn->SetMessageFreeCallback(free_cb);
483 }
484 conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
485 MS_LOG(INFO) << "Connected to destination: " << dst_url;
486 return true;
487 }
488
IsConnected(const std::string & dst_url)489 bool TCPComm::IsConnected(const std::string &dst_url) {
490 MS_EXCEPTION_IF_NULL(conn_pool_);
491 Connection *conn = conn_pool_->FindConnection(dst_url);
492 if (conn != nullptr && conn->state == ConnectionState::kConnected) {
493 return true;
494 }
495 return false;
496 }
497
Disconnect(const std::string & dst_url)498 bool TCPComm::Disconnect(const std::string &dst_url) {
499 MS_EXCEPTION_IF_NULL(conn_mutex_);
500 MS_EXCEPTION_IF_NULL(conn_pool_);
501 MS_EXCEPTION_IF_NULL(recv_event_loop_);
502 MS_EXCEPTION_IF_NULL(send_event_loop_);
503
504 unsigned int interval = 100000;
505 size_t retry = 30;
506 while (recv_event_loop_->RemainingTaskNum() != 0 && send_event_loop_->RemainingTaskNum() != 0 && retry > 0) {
507 (void)usleep(interval);
508 retry--;
509 }
510 if (recv_event_loop_->RemainingTaskNum() > 0 || send_event_loop_->RemainingTaskNum() > 0) {
511 MS_LOG(ERROR) << "Failed to disconnect from url " << dst_url
512 << ", because there are still pending tasks to be executed, please try later.";
513 return false;
514 }
515 std::lock_guard<std::mutex> lock(*conn_mutex_);
516 auto conn = conn_pool_->FindConnection(dst_url);
517 if (conn != nullptr) {
518 std::lock_guard<std::mutex> conn_lock(conn->conn_owned_mutex_);
519 conn_pool_->DeleteConnection(dst_url);
520 }
521 return true;
522 }
523
CreateDefaultConn(const std::string & to)524 Connection *TCPComm::CreateDefaultConn(const std::string &to) {
525 Connection *conn = new (std::nothrow) Connection();
526 if (conn == nullptr) {
527 MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
528 return conn;
529 }
530 conn->enable_ssl = enable_ssl_;
531 conn->source = url_.data();
532 conn->destination = to;
533 conn->recv_event_loop = this->recv_event_loop_;
534 conn->send_event_loop = this->send_event_loop_;
535 conn->conn_mutex = conn_mutex_;
536 conn->message_handler = message_handler_;
537 conn->InitSocketOperation();
538 return conn;
539 }
540
Finalize()541 void TCPComm::Finalize() {
542 if (send_event_loop_ != nullptr) {
543 MS_LOG(INFO) << "Delete send event loop";
544 send_event_loop_->Finalize();
545 delete send_event_loop_;
546 send_event_loop_ = nullptr;
547 }
548
549 if (recv_event_loop_ != nullptr) {
550 MS_LOG(INFO) << "Delete recv event loop";
551 recv_event_loop_->Finalize();
552 delete recv_event_loop_;
553 recv_event_loop_ = nullptr;
554 }
555
556 if (server_fd_ > 0) {
557 if (close(server_fd_) != 0) {
558 MS_LOG(ERROR) << "Failed to close fd: " << server_fd_;
559 }
560 server_fd_ = -1;
561 }
562
563 if (conn_pool_ != nullptr) {
564 MS_LOG(INFO) << "Delete connection pool.";
565 conn_pool_->Finalize();
566 conn_pool_.reset();
567 conn_pool_ = nullptr;
568 }
569 }
570 } // namespace rpc
571 } // namespace distributed
572 } // namespace mindspore
573