• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "distributed/rpc/tcp/tcp_comm.h"
18 
19 #include <mutex>
20 #include <utility>
21 #include <memory>
22 
23 #include "actor/aid.h"
24 #include "include/backend/distributed/rpc/tcp/constants.h"
25 #include "distributed/rpc/tcp/tcp_socket_operation.h"
26 
27 namespace mindspore {
28 namespace distributed {
29 namespace rpc {
DoDisconnect(int fd,Connection * conn,uint32_t error,int soError)30 void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError) {
31   if (conn == nullptr) {
32     return;
33   }
34   if (LOG_CHECK_EVERY_N()) {
35     MS_LOG(INFO) << "Failed to call connect, fd: " << fd << "from : " << conn->source << " to: " << conn->destination
36                  << ", events: " << error << ", errno: " << soError << " " << strerror(soError);
37   }
38 
39   conn->state = ConnectionState::kDisconnecting;
40   conn->error_code = soError;
41   conn->event_callback(conn);
42   return;
43 }
44 
ConnectedEventHandler(int fd,uint32_t events,void * context)45 void ConnectedEventHandler(int fd, uint32_t events, void *context) {
46   uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
47   int soError = 0;
48   Connection *conn = reinterpret_cast<Connection *>(context);
49   if (conn == nullptr || conn->socket_operation == nullptr) {
50     return;
51   }
52   if (conn->state == ConnectionState::kDisconnecting) {
53     DoDisconnect(fd, conn, error, soError);
54     return;
55   }
56 
57   if (!conn->ReconnectSourceSocket(fd, events, &soError, error)) {
58     DoDisconnect(fd, conn, error, soError);
59     return;
60   }
61   if (conn->write_callback) {
62     conn->write_callback(conn);
63   }
64   MS_LOG(WARNING) << "Connection from " << conn->source << " to " << conn->destination << "is successfully created "
65                   << strerror(errno);
66   conn->socket_operation->ConnEstablishedEventHandler(fd, events, context);
67   return;
68 }
69 
OnAccept(int server,uint32_t events,void * arg)70 void OnAccept(int server, uint32_t events, void *arg) {
71   if (events & (EPOLLHUP | EPOLLERR)) {
72     MS_LOG(ERROR) << "Invalid error event, server fd: " << server << ", events: " << events;
73     return;
74   }
75   TCPComm *tcpmgr = reinterpret_cast<TCPComm *>(arg);
76   if (tcpmgr == nullptr || tcpmgr->conn_pool_ == nullptr) {
77     return;
78   }
79   if (tcpmgr->recv_event_loop_ == nullptr) {
80     MS_LOG(ERROR) << "EventLoop is null, server fd: " << server << ", events: " << events;
81     return;
82   }
83 
84   // accept connection
85   auto acceptFd = SocketOperation::Accept(server);
86   if (acceptFd < 0) {
87     MS_LOG(ERROR) << "Failed to call accept, server fd: " << server << ", events: " << events;
88     return;
89   }
90 
91   // This new connection will be added to connection pool.
92   Connection *conn = new (std::nothrow) Connection();
93   if (conn == nullptr) {
94     MS_LOG(ERROR) << "Failed to create new connection, server fd:" << server << ", events: " << events
95                   << ", accept fd: " << acceptFd;
96     if (close(acceptFd) != 0) {
97       MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
98     }
99     return;
100   }
101   conn->enable_ssl = tcpmgr->enable_ssl_;
102 
103   // init metrics
104   conn->send_metrics = new (std::nothrow) SendMetrics();
105   if (conn->send_metrics == nullptr) {
106     MS_LOG(ERROR) << "Failed to create connection metrics, server fd: " << server << ", events: " << events
107                   << ", accept fd: " << acceptFd;
108     if (close(acceptFd) != 0) {
109       MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
110     }
111     delete conn;
112     return;
113   }
114 
115   conn->socket_fd = acceptFd;
116   conn->source = tcpmgr->url_;
117   conn->destination = SocketOperation::GetPeer(acceptFd);
118   conn->peer = conn->destination;
119 
120   conn->is_remote = true;
121   conn->recv_event_loop = tcpmgr->recv_event_loop_;
122   conn->send_event_loop = tcpmgr->send_event_loop_;
123 
124   conn->conn_mutex = tcpmgr->conn_mutex_;
125   conn->message_handler = tcpmgr->message_handler_;
126 
127   conn->event_callback = std::bind(&TCPComm::EventCallBack, tcpmgr, std::placeholders::_1);
128   conn->write_callback = std::bind(&TCPComm::WriteCallBack, tcpmgr, std::placeholders::_1);
129   conn->read_callback = std::bind(&TCPComm::ReadCallBack, tcpmgr, std::placeholders::_1);
130 
131   conn->SetAllocateCallback(tcpmgr->allocate_cb());
132 
133   int retval = conn->Initialize();
134   if (retval != RPC_OK) {
135     MS_LOG(ERROR) << "Failed to add accept fd event, server fd: " << server << ", events: " << events
136                   << ", accept fd: " << acceptFd;
137     if (close(acceptFd) != 0) {
138       MS_LOG(ERROR) << "Failed to close fd: " << acceptFd;
139     }
140     acceptFd = -1;
141     delete conn->send_metrics;
142     delete conn;
143     conn = nullptr;
144     return;
145   }
146   tcpmgr->conn_pool_->AddConnection(conn);
147 }
148 
SetMessageHandler(const MessageHandler & handler)149 void TCPComm::SetMessageHandler(const MessageHandler &handler) { message_handler_ = handler; }
150 
Initialize()151 bool TCPComm::Initialize() {
152   conn_pool_ = std::make_shared<ConnectionPool>();
153   MS_EXCEPTION_IF_NULL(conn_pool_);
154 
155   conn_mutex_ = std::make_shared<std::mutex>();
156   MS_EXCEPTION_IF_NULL(conn_mutex_);
157 
158   recv_event_loop_ = new (std::nothrow) EventLoop();
159   if (recv_event_loop_ == nullptr) {
160     MS_LOG(ERROR) << "Failed to create recv evLoop.";
161     return false;
162   }
163 
164   bool ok = recv_event_loop_->Initialize(TCP_RECV_EVLOOP_THREADNAME);
165   if (!ok) {
166     MS_LOG(ERROR) << "Failed to init recv evLoop";
167     delete recv_event_loop_;
168     recv_event_loop_ = nullptr;
169     return false;
170   }
171 
172   send_event_loop_ = new (std::nothrow) EventLoop();
173   if (send_event_loop_ == nullptr) {
174     MS_LOG(ERROR) << "Failed to create send evLoop.";
175     delete recv_event_loop_;
176     recv_event_loop_ = nullptr;
177     return false;
178   }
179   ok = send_event_loop_->Initialize(TCP_SEND_EVLOOP_THREADNAME);
180   if (!ok) {
181     MS_LOG(ERROR) << "Failed to init send evLoop";
182     delete recv_event_loop_;
183     recv_event_loop_ = nullptr;
184     delete send_event_loop_;
185     send_event_loop_ = nullptr;
186     return false;
187   }
188 
189   return true;
190 }
191 
StartServerSocket(const std::string & url,const MemAllocateCallback & allocate_cb)192 int TCPComm::StartServerSocket(const std::string &url, const MemAllocateCallback &allocate_cb) {
193   server_fd_ = SocketOperation::Listen(url);
194   if (server_fd_ < 0) {
195     MS_LOG(WARNING) << "Failed to call socket listen, url: " << url.c_str();
196     return server_fd_;
197   }
198   url_ = url;
199   allocate_cb_ = allocate_cb;
200   size_t index = url.find(URL_PROTOCOL_IP_SEPARATOR);
201   if (index != std::string::npos) {
202     index = index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1;
203     if (index < url.length()) {
204       url_ = url.substr(index);
205     }
206   }
207 
208   // Register read event callback for server socket
209   int retval = recv_event_loop_->SetEventHandler(server_fd_, EPOLLIN | EPOLLHUP | EPOLLERR, OnAccept,
210                                                  reinterpret_cast<void *>(this));
211   if (retval != RPC_OK) {
212     MS_LOG(ERROR) << "Failed to add server event, url: " << url.c_str();
213     return -1;
214   }
215   MS_LOG(INFO) << "Start server succ, fd: " << server_fd_ << ", url: " << url.c_str();
216   return 0;
217 }
218 
StartServerSocket(const MemAllocateCallback & allocate_cb)219 int TCPComm::StartServerSocket(const MemAllocateCallback &allocate_cb) {
220   auto ip = SocketOperation::GetLocalIP();
221   // The port 0 means that the port will be allocated randomly by the os system.
222   auto url = ip + ":0";
223   return StartServerSocket(url, allocate_cb);
224 }
225 
GetServerFd() const226 int TCPComm::GetServerFd() const { return server_fd_; }
227 
ReadCallBack(void * connection)228 void TCPComm::ReadCallBack(void *connection) {
229   const int max_recv_count = 3;
230   Connection *conn = reinterpret_cast<Connection *>(connection);
231   if (conn == nullptr) {
232     return;
233   }
234   int count = 0;
235   int retval = 0;
236   do {
237     retval = ReceiveMessage(conn);
238     ++count;
239   } while (retval > 0 && count < max_recv_count);
240 
241   return;
242 }
243 
EventCallBack(void * connection)244 void TCPComm::EventCallBack(void *connection) {
245   Connection *conn = reinterpret_cast<Connection *>(connection);
246   if (conn == nullptr) {
247     return;
248   }
249   if (conn->state == ConnectionState::kConnected) {
250     conn->conn_mutex->lock();
251     (void)conn->Flush();
252     conn->conn_mutex->unlock();
253   } else if (conn->state == ConnectionState::kDisconnecting) {
254     std::lock_guard<std::mutex> lock(*conn_mutex_);
255     auto current_conn = conn_pool_->FindConnection(conn->destination);
256     if (current_conn != nullptr) {
257       if (current_conn->source != conn->source) {
258         MS_LOG(WARNING) << "Current connection created to " << conn->destination << " is from " << current_conn->source
259                         << ", not " << conn->source << ". No need to delete.";
260         return;
261       }
262     }
263     MS_LOG(INFO) << "The connection state is kDisconnecting. Start disconnecting from " << conn->source << " to "
264                  << conn->destination;
265     conn_pool_->DeleteConnection(conn->destination);
266   }
267 }
268 
WriteCallBack(void * connection)269 void TCPComm::WriteCallBack(void *connection) {
270   Connection *conn = reinterpret_cast<Connection *>(connection);
271   if (conn == nullptr) {
272     return;
273   }
274   if (conn->state == ConnectionState::kConnected) {
275     conn->conn_mutex->lock();
276     (void)conn->Flush();
277     conn->conn_mutex->unlock();
278   }
279 }
280 
281 /* static method */
ReceiveMessage(Connection * conn)282 int TCPComm::ReceiveMessage(Connection *conn) {
283   std::lock_guard<std::mutex> lock(*conn->conn_mutex);
284   conn->CheckMessageType();
285   switch (conn->recv_message_type) {
286     case ParseType::kTcpMsg:
287       return conn->ReceiveMessage();
288     default:
289       return 0;
290   }
291 }
292 
293 /* static method */
SetConnectedHandler(Connection * conn)294 int TCPComm::SetConnectedHandler(Connection *conn) {
295   /* add to epoll */
296   return conn->recv_event_loop->SetEventHandler(conn->socket_fd,
297                                                 static_cast<uint32_t>(EPOLLOUT | EPOLLHUP | EPOLLRDHUP | EPOLLERR),
298                                                 ConnectedEventHandler, reinterpret_cast<void *>(conn));
299 }
300 
301 /* static method */
DoConnect(Connection * conn,const struct sockaddr * sa,socklen_t saLen)302 int TCPComm::DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen) {
303   if (conn == nullptr || conn->recv_event_loop == nullptr || sa == nullptr) {
304     return RPC_ERROR;
305   }
306   int retval = 0;
307   uint16_t localPort = 0;
308 
309   retval = SocketOperation::Connect(conn->socket_fd, sa, saLen, &localPort);
310   if (retval != RPC_OK) {
311     return RPC_ERROR;
312   }
313 
314   // Init connection metrics.
315   if (conn->send_metrics == nullptr) {
316     conn->send_metrics = new (std::nothrow) SendMetrics();
317     if (conn->send_metrics == nullptr) {
318       return RPC_ERROR;
319     }
320   }
321 
322   // Add the socket of this connection to epoll.
323   retval = SetConnectedHandler(conn);
324   if (retval != RPC_OK) {
325     if (conn->send_metrics != nullptr) {
326       delete conn->send_metrics;
327       conn->send_metrics = nullptr;
328     }
329     return RPC_ERROR;
330   }
331   return RPC_OK;
332 }
333 
334 /* static method */
DropMessage(MessageBase * msg)335 void TCPComm::DropMessage(MessageBase *msg) {
336   auto *ptr = msg;
337   delete ptr;
338   ptr = nullptr;
339 }
340 
Send(MessageBase * msg,size_t * const send_bytes,bool sync)341 bool TCPComm::Send(MessageBase *msg, size_t *const send_bytes, bool sync) {
342   if (msg == nullptr) {
343     return false;
344   }
345   auto task = [msg, send_bytes, this] {
346     std::lock_guard<std::mutex> lock(*conn_mutex_);
347     // Search connection by the target address
348     std::string destination = msg->to.Url();
349     Connection *conn = conn_pool_->FindConnection(destination);
350     if (conn == nullptr) {
351       MS_LOG(WARNING) << "Can not found remote link and send fail name: " << msg->name.c_str()
352                       << ", from: " << msg->from.Url().c_str() << ", to: " << destination;
353       DropMessage(msg);
354       return false;
355     }
356 
357     if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) {
358       MS_LOG(WARNING) << "The message queue is full(max len:" << SENDMSG_QUEUELEN
359                       << ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
360                       << ", to: " << conn->destination.c_str();
361       if (!conn->FreeMessageMemory(msg)) {
362         MS_LOG(ERROR) << "Failed to free memory of the message.";
363       }
364       DropMessage(msg);
365       return false;
366     }
367 
368     if (conn->state != ConnectionState::kConnected) {
369       MS_LOG(WARNING) << "Invalid connection state " << conn->state
370                       << " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd
371                       << ", to: " << conn->destination.c_str();
372       if (!conn->FreeMessageMemory(msg)) {
373         MS_LOG(ERROR) << "Failed to free memory of the message.";
374       }
375       DropMessage(msg);
376       return false;
377     }
378 
379     if (conn->total_send_len == 0) {
380       conn->FillSendMessage(msg, url_, false);
381     } else {
382       (void)conn->send_message_queue.emplace(msg);
383     }
384     auto bytes = conn->Flush();
385     if (send_bytes != nullptr) {
386       *send_bytes = bytes;
387     }
388     return true;
389   };
390   if (sync) {
391     return task();
392   } else {
393     send_event_loop_->AddTask(task);
394     return true;
395   }
396 }
397 
Flush(const std::string & dst_url)398 bool TCPComm::Flush(const std::string &dst_url) {
399   Connection *conn = conn_pool_->FindConnection(dst_url);
400   if (conn == nullptr) {
401     MS_LOG(ERROR) << "Can not find the connection to url: " << dst_url;
402     return false;
403   } else {
404     std::lock_guard<std::mutex> lock(*(conn->conn_mutex));
405     return (conn->Flush() >= 0);
406   }
407 }
408 
Connect(const std::string & dst_url,const MemFreeCallback & free_cb)409 bool TCPComm::Connect(const std::string &dst_url, const MemFreeCallback &free_cb) {
410   MS_EXCEPTION_IF_NULL(conn_mutex_);
411   MS_EXCEPTION_IF_NULL(conn_pool_);
412   if (!free_cb) {
413     MS_LOG(EXCEPTION) << "The message callback is empty.";
414   }
415 
416   std::lock_guard<std::mutex> lock(*conn_mutex_);
417 
418   // Search connection by the target address
419   Connection *conn = conn_pool_->FindConnection(dst_url);
420 
421   if (conn == nullptr) {
422     MS_LOG(INFO) << "Can not found link destination: " << dst_url;
423     conn = new (std::nothrow) Connection();
424     if (conn == nullptr) {
425       MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url;
426       return false;
427     }
428     conn->enable_ssl = enable_ssl_;
429     conn->recv_event_loop = this->recv_event_loop_;
430     conn->send_event_loop = this->send_event_loop_;
431     conn->conn_mutex = conn_mutex_;
432     conn->message_handler = message_handler_;
433     conn->InitSocketOperation();
434 
435     // Create the client socket.
436     SocketAddress addr;
437     if (!SocketOperation::GetSockAddr(dst_url, &addr)) {
438       MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url;
439       return false;
440     }
441     int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family);
442     if (sock_fd < 0) {
443       MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url;
444       return false;
445     }
446 
447     conn->socket_fd = sock_fd;
448     conn->event_callback = std::bind(&TCPComm::EventCallBack, this, std::placeholders::_1);
449     conn->write_callback = std::bind(&TCPComm::WriteCallBack, this, std::placeholders::_1);
450     conn->read_callback = std::bind(&TCPComm::ReadCallBack, this, std::placeholders::_1);
451 
452     int ret = TCPComm::DoConnect(conn, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr));
453     if (ret < 0) {
454       MS_LOG(ERROR) << "Failed to do connect and link fail destination: " << dst_url;
455       if (conn->socket_operation != nullptr) {
456         delete conn->socket_operation;
457         conn->socket_operation = nullptr;
458       }
459       delete conn;
460       conn = nullptr;
461       return false;
462     }
463     conn->source = SocketOperation::GetIP(sock_fd) + ":" + std::to_string(SocketOperation::GetPort(sock_fd));
464     conn->destination = dst_url;
465     dst_url_to_src_ip_[dst_url] = SocketOperation::GetIP(sock_fd);
466     MS_LOG(INFO) << "Connection " << sock_fd << " source: " << conn->source << ", destination: " << conn->destination;
467 
468     // Check the state of this new created connection.
469     uint32_t interval = 1;
470     size_t retry = 3;
471     // Record total retry number to avoid duplicated log.
472     static size_t total_retry_count = 0;
473     while (conn->state < ConnectionState::kConnected && retry-- > 0) {
474       MS_LOG(WARNING) << "Waiting for the state of the connection to " << dst_url
475                       << " to be connected...Retry number: " << ++total_retry_count;
476       (void)sleep(interval);
477     }
478     if (conn->state != ConnectionState::kConnected) {
479       return false;
480     }
481     conn_pool_->AddConnection(conn);
482     conn->SetMessageFreeCallback(free_cb);
483   }
484   conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr);
485   MS_LOG(INFO) << "Connected to destination: " << dst_url;
486   return true;
487 }
488 
IsConnected(const std::string & dst_url)489 bool TCPComm::IsConnected(const std::string &dst_url) {
490   MS_EXCEPTION_IF_NULL(conn_pool_);
491   Connection *conn = conn_pool_->FindConnection(dst_url);
492   if (conn != nullptr && conn->state == ConnectionState::kConnected) {
493     return true;
494   }
495   return false;
496 }
497 
Disconnect(const std::string & dst_url)498 bool TCPComm::Disconnect(const std::string &dst_url) {
499   MS_EXCEPTION_IF_NULL(conn_mutex_);
500   MS_EXCEPTION_IF_NULL(conn_pool_);
501   MS_EXCEPTION_IF_NULL(recv_event_loop_);
502   MS_EXCEPTION_IF_NULL(send_event_loop_);
503 
504   unsigned int interval = 100000;
505   size_t retry = 30;
506   while (recv_event_loop_->RemainingTaskNum() != 0 && send_event_loop_->RemainingTaskNum() != 0 && retry > 0) {
507     (void)usleep(interval);
508     retry--;
509   }
510   if (recv_event_loop_->RemainingTaskNum() > 0 || send_event_loop_->RemainingTaskNum() > 0) {
511     MS_LOG(ERROR) << "Failed to disconnect from url " << dst_url
512                   << ", because there are still pending tasks to be executed, please try later.";
513     return false;
514   }
515   std::lock_guard<std::mutex> lock(*conn_mutex_);
516   auto conn = conn_pool_->FindConnection(dst_url);
517   if (conn != nullptr) {
518     std::lock_guard<std::mutex> conn_lock(conn->conn_owned_mutex_);
519     conn_pool_->DeleteConnection(dst_url);
520   }
521   return true;
522 }
523 
CreateDefaultConn(const std::string & to)524 Connection *TCPComm::CreateDefaultConn(const std::string &to) {
525   Connection *conn = new (std::nothrow) Connection();
526   if (conn == nullptr) {
527     MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
528     return conn;
529   }
530   conn->enable_ssl = enable_ssl_;
531   conn->source = url_.data();
532   conn->destination = to;
533   conn->recv_event_loop = this->recv_event_loop_;
534   conn->send_event_loop = this->send_event_loop_;
535   conn->conn_mutex = conn_mutex_;
536   conn->message_handler = message_handler_;
537   conn->InitSocketOperation();
538   return conn;
539 }
540 
Finalize()541 void TCPComm::Finalize() {
542   if (send_event_loop_ != nullptr) {
543     MS_LOG(INFO) << "Delete send event loop";
544     send_event_loop_->Finalize();
545     delete send_event_loop_;
546     send_event_loop_ = nullptr;
547   }
548 
549   if (recv_event_loop_ != nullptr) {
550     MS_LOG(INFO) << "Delete recv event loop";
551     recv_event_loop_->Finalize();
552     delete recv_event_loop_;
553     recv_event_loop_ = nullptr;
554   }
555 
556   if (server_fd_ > 0) {
557     if (close(server_fd_) != 0) {
558       MS_LOG(ERROR) << "Failed to close fd: " << server_fd_;
559     }
560     server_fd_ = -1;
561   }
562 
563   if (conn_pool_ != nullptr) {
564     MS_LOG(INFO) << "Delete connection pool.";
565     conn_pool_->Finalize();
566     conn_pool_.reset();
567     conn_pool_ = nullptr;
568   }
569 }
570 }  // namespace rpc
571 }  // namespace distributed
572 }  // namespace mindspore
573