• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/tcp_server.h"
18 
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <event2/listener.h>
25 #include <event2/util.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #include <csignal>
30 #include <utility>
31 
32 namespace mindspore {
33 namespace ps {
34 namespace core {
~TcpConnection()35 TcpConnection::~TcpConnection() {
36   MS_LOG(WARNING) << "TcpConnection is destructed! fd is " << fd_;
37   bufferevent_free(buffer_event_);
38 }
InitConnection(const messageReceive & callback)39 void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
40 
OnReadHandler(const void * buffer,size_t num)41 void TcpConnection::OnReadHandler(const void *buffer, size_t num) {
42   MS_EXCEPTION_IF_NULL(buffer);
43   tcp_message_handler_.ReceiveMessage(buffer, num);
44 }
45 
SendMessage(const void * buffer,size_t num) const46 void TcpConnection::SendMessage(const void *buffer, size_t num) const {
47   MS_EXCEPTION_IF_NULL(buffer);
48   MS_EXCEPTION_IF_NULL(buffer_event_);
49   bufferevent_lock(buffer_event_);
50   if (bufferevent_write(buffer_event_, buffer, num) == -1) {
51     MS_LOG(ERROR) << "Write message to buffer event failed!";
52   }
53   bufferevent_unlock(buffer_event_);
54 }
55 
GetServer() const56 const TcpServer *TcpConnection::GetServer() const { return server_; }
57 
GetFd() const58 const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
59 
set_callback(const Callback & callback)60 void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
61 
SendMessage(const std::shared_ptr<CommMessage> & message) const62 bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
63   MS_EXCEPTION_IF_NULL(buffer_event_);
64   MS_EXCEPTION_IF_NULL(message);
65   bufferevent_lock(buffer_event_);
66   bool res = true;
67   size_t buf_size = message->ByteSizeLong();
68   if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
69     MS_LOG(ERROR) << "Event buffer add header failed!";
70     res = false;
71   }
72   if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
73     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
74     res = false;
75   }
76   bufferevent_unlock(buffer_event_);
77   return res;
78 }
79 
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size) const80 bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
81                                 size_t size) const {
82   MS_EXCEPTION_IF_NULL(buffer_event_);
83   MS_EXCEPTION_IF_NULL(meta);
84   MS_EXCEPTION_IF_NULL(data);
85   bufferevent_lock(buffer_event_);
86   bool res = true;
87   MessageHeader header;
88   header.message_proto_ = protos;
89   header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
90   header.message_length_ = size + header.message_meta_length_;
91 
92   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
93     MS_LOG(ERROR) << "Event buffer add header failed!";
94     res = false;
95   }
96   if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
97     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
98     res = false;
99   }
100   if (bufferevent_write(buffer_event_, data, size) == -1) {
101     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
102     res = false;
103   }
104   int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
105   if (result < 0) {
106     bufferevent_unlock(buffer_event_);
107     MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
108   }
109   bufferevent_unlock(buffer_event_);
110   return res;
111 }
112 
TcpServer(const std::string & address,std::uint16_t port,Configuration * const config,const std::pair<uint32_t,uint32_t> & port_range)113 TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config,
114                      const std::pair<uint32_t, uint32_t> &port_range)
115     : base_(nullptr),
116       signal_event_(nullptr),
117       listener_(nullptr),
118       server_address_(std::move(address)),
119       server_port_(port),
120       is_stop_(true),
121       config_(config),
122       max_connection_(0),
123       port_range_(port_range) {}
124 
~TcpServer()125 TcpServer::~TcpServer() {
126   if (signal_event_ != nullptr) {
127     event_free(signal_event_);
128     signal_event_ = nullptr;
129   }
130 
131   if (listener_ != nullptr) {
132     evconnlistener_free(listener_);
133     listener_ = nullptr;
134   }
135 
136   if (base_ != nullptr) {
137     event_base_free(base_);
138     base_ = nullptr;
139   }
140 }
141 
SetServerCallback(const OnConnected & client_conn,const OnDisconnected & client_disconn,const OnAccepted & client_accept)142 void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
143                                   const OnAccepted &client_accept) {
144   this->client_connection_ = client_conn;
145   this->client_disconnection_ = client_disconn;
146   this->client_accept_ = client_accept;
147 }
148 
Init()149 void TcpServer::Init() {
150   if (PSContext::instance()->enable_ssl()) {
151     MS_LOG(INFO) << "Init ssl.";
152     SSLWrapper::GetInstance().InitSSL();
153   }
154   int result = evthread_use_pthreads();
155   if (result != 0) {
156     MS_LOG(EXCEPTION) << "Use event pthread failed!";
157   }
158 
159   is_stop_ = false;
160   base_ = event_base_new();
161   MS_EXCEPTION_IF_NULL(base_);
162   if (!CommUtil::CheckIp(server_address_)) {
163     MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
164   }
165   MS_EXCEPTION_IF_NULL(config_);
166   max_connection_ = kConnectionNumDefault;
167   if (config_->Exists(kConnectionNum)) {
168     max_connection_ = config_->GetInt(kConnectionNum, 0);
169   }
170   MS_LOG(INFO) << "The max connection is:" << max_connection_;
171 
172   struct sockaddr_in sin {};
173   if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
174     MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
175   }
176   sin.sin_family = AF_INET;
177   sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
178   // We do not use server_port_ because it's always 0.
179   MS_LOG(INFO) << "Initialize TcpServer with port range " << port_range_.first << " to " << port_range_.second;
180   uint16_t current_port = static_cast<uint16_t>(port_range_.first);
181   do {
182     sin.sin_port = htons(current_port);
183     listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
184                                         LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
185                                         reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
186     if (listener_ == nullptr) {
187       current_port++;
188       MS_LOG(WARNING) << "The port " << current_port << " is already in use. So increase port to: " << current_port;
189       if (current_port > port_range_.second) {
190         MS_LOG(EXCEPTION) << "Port range " << port_range_.first << " to " << port_range_.second
191                           << " are all in use already. You can run 'netstat -anp|grep <port number>' command to check "
192                              "which process occupies the port.";
193       }
194     }
195   } while (listener_ == nullptr);
196 
197   if (server_port_ == 0) {
198     struct sockaddr_in sin_bound {};
199     if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) {
200       MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
201     }
202     socklen_t addr_len = sizeof(struct sockaddr_in);
203     if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) {
204       MS_LOG(EXCEPTION) << "Get sock name failed!";
205     }
206     server_port_ = htons(sin_bound.sin_port);
207   }
208 
209   signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
210   MS_EXCEPTION_IF_NULL(signal_event_);
211   if (event_add(signal_event_, nullptr) < 0) {
212     MS_LOG(EXCEPTION) << "Cannot create signal event.";
213   }
214 }
215 
Start()216 void TcpServer::Start() {
217   MS_LOG(INFO) << "Start tcp server!";
218   MS_EXCEPTION_IF_NULL(base_);
219   int ret = event_base_dispatch(base_);
220   MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base dispatch success!";
221   MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
222     << "Event base dispatch failed with no events pending or active!";
223   MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr)
224     << "Event base dispatch failed with error occurred!";
225   MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr)
226     << "Event base dispatch with unexpected error code!";
227 }
228 
Stop()229 void TcpServer::Stop() {
230   MS_ERROR_IF_NULL_WO_RET_VAL(base_);
231   std::lock_guard<std::mutex> lock(connection_mutex_);
232   MS_LOG(INFO) << "Stop tcp server!";
233   if (event_base_got_break(base_)) {
234     MS_LOG(DEBUG) << "The event base has already been stopped!";
235     is_stop_ = true;
236     return;
237   }
238   if (!is_stop_.load()) {
239     is_stop_ = true;
240     int ret = event_base_loopbreak(base_);
241     if (ret != 0) {
242       MS_LOG(ERROR) << "Event base loop break failed!";
243     }
244   }
245 }
246 
SendToAllClients(const char * data,size_t len)247 void TcpServer::SendToAllClients(const char *data, size_t len) {
248   MS_EXCEPTION_IF_NULL(data);
249   std::lock_guard<std::mutex> lock(connection_mutex_);
250   for (auto it = connections_.begin(); it != connections_.end(); ++it) {
251     it->second->SendMessage(data, len);
252   }
253 }
254 
AddConnection(const evutil_socket_t & fd,std::shared_ptr<TcpConnection> connection)255 void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
256   MS_EXCEPTION_IF_NULL(connection);
257   std::lock_guard<std::mutex> lock(connection_mutex_);
258   (void)connections_.insert(std::make_pair(fd, connection));
259 }
260 
RemoveConnection(const evutil_socket_t & fd)261 void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
262   std::lock_guard<std::mutex> lock(connection_mutex_);
263   MS_LOG(INFO) << "Remove connection fd: " << fd;
264   (void)connections_.erase(fd);
265 }
266 
GetConnectionByFd(const evutil_socket_t & fd)267 std::shared_ptr<TcpConnection> &TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
268 
ListenerCallback(struct evconnlistener *,evutil_socket_t fd,struct sockaddr * sockaddr,int,void * const data)269 void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
270                                  void *const data) {
271   try {
272     ListenerCallbackInner(fd, sockaddr, data);
273   } catch (const std::exception &e) {
274     MS_LOG(ERROR) << "Catch exception: " << e.what();
275   }
276 }
277 
ListenerCallbackInner(evutil_socket_t fd,struct sockaddr * sockaddr,void * const data)278 void TcpServer::ListenerCallbackInner(evutil_socket_t fd, struct sockaddr *sockaddr, void *const data) {
279   auto server = reinterpret_cast<class TcpServer *>(data);
280   MS_EXCEPTION_IF_NULL(server);
281   auto base = reinterpret_cast<struct event_base *>(server->base_);
282   MS_EXCEPTION_IF_NULL(base);
283   MS_EXCEPTION_IF_NULL(sockaddr);
284 
285   if (server->ConnectionNum() >= server->max_connection_) {
286     MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to "
287                     << server->max_connection_;
288     return;
289   }
290 
291   struct bufferevent *bev = nullptr;
292 
293   if (!PSContext::instance()->enable_ssl()) {
294     MS_LOG(INFO) << "SSL is disable.";
295     bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
296   } else {
297     MS_LOG(INFO) << "Enable ssl support.";
298     SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());
299     MS_EXCEPTION_IF_NULL(ssl);
300     bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING,
301                                          BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
302   }
303   if (bev == nullptr) {
304     MS_LOG(ERROR) << "Error constructing buffer event!";
305     int ret = event_base_loopbreak(base);
306     if (ret != 0) {
307       MS_LOG(EXCEPTION) << "event base loop break failed!";
308     }
309     return;
310   }
311 
312   std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
313   MS_EXCEPTION_IF_NULL(conn);
314   SetTcpNoDelay(fd);
315   server->AddConnection(fd, conn);
316   conn->InitConnection(
317     [=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
318       OnServerReceiveMessage on_server_receive = server->GetServerReceive();
319       if (on_server_receive) {
320         on_server_receive(conn, meta, protos, data, size);
321       }
322     });
323 
324   bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
325                     reinterpret_cast<void *>(conn.get()));
326   MS_LOG(INFO) << "A client is connected, fd is " << fd;
327   if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
328     MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
329   }
330 }
331 
onCreateConnection(struct bufferevent * bev,const evutil_socket_t & fd)332 std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
333   MS_EXCEPTION_IF_NULL(bev);
334   std::shared_ptr<TcpConnection> conn = nullptr;
335   if (client_accept_) {
336     conn = (client_accept_(*this));
337   } else {
338     conn = std::make_shared<TcpConnection>(bev, fd, this);
339   }
340 
341   return conn;
342 }
343 
GetServerReceive() const344 OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
345 
SignalCallback(evutil_socket_t,std::int16_t,void * const data)346 void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *const data) {
347   try {
348     SignalCallbackInner(data);
349   } catch (const std::exception &e) {
350     MS_LOG(ERROR) << "Catch exception: " << e.what();
351   }
352 }
353 
SignalCallbackInner(void * const data)354 void TcpServer::SignalCallbackInner(void *const data) {
355   MS_EXCEPTION_IF_NULL(data);
356   auto server = reinterpret_cast<class TcpServer *>(data);
357   struct event_base *base = server->base_;
358   MS_EXCEPTION_IF_NULL(base);
359   struct timeval delay = {0, 0};
360   MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
361   if (event_base_loopexit(base, &delay) == -1) {
362     MS_LOG(ERROR) << "Event base loop exit failed.";
363   }
364 }
365 
ReadCallback(struct bufferevent * bev,void * const connection)366 void TcpServer::ReadCallback(struct bufferevent *bev, void *const connection) {
367   try {
368     ReadCallbackInner(bev, connection);
369   } catch (const std::exception &e) {
370     MS_LOG(ERROR) << "Catch exception: " << e.what();
371   }
372 }
373 
ReadCallbackInner(struct bufferevent * bev,void * const connection)374 void TcpServer::ReadCallbackInner(struct bufferevent *bev, void *const connection) {
375   MS_EXCEPTION_IF_NULL(bev);
376   MS_EXCEPTION_IF_NULL(connection);
377 
378   auto conn = static_cast<class TcpConnection *>(connection);
379   struct evbuffer *buf = bufferevent_get_input(bev);
380   MS_EXCEPTION_IF_NULL(buf);
381   char read_buffer[kMessageChunkLength];
382   while (EVBUFFER_LENGTH(buf) > 0) {
383     int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
384     if (read == -1) {
385       MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
386     }
387     conn->OnReadHandler(read_buffer, IntToSize(read));
388   }
389 }
390 
EventCallback(struct bufferevent * bev,std::int16_t events,void * const data)391 void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *const data) {
392   try {
393     EventCallbackInner(bev, events, data);
394   } catch (const std::exception &e) {
395     MS_LOG(ERROR) << "Catch exception: " << e.what();
396   }
397 }
398 
EventCallbackInner(struct bufferevent * bev,std::int16_t events,void * const data)399 void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *const data) {
400   MS_EXCEPTION_IF_NULL(bev);
401   MS_EXCEPTION_IF_NULL(data);
402   struct evbuffer *output = bufferevent_get_output(bev);
403   MS_EXCEPTION_IF_NULL(output);
404   auto conn = static_cast<class TcpConnection *>(data);
405   auto srv = const_cast<TcpServer *>(conn->GetServer());
406   MS_EXCEPTION_IF_NULL(srv);
407 
408   if (events & BEV_EVENT_EOF) {
409     MS_LOG(INFO) << "BEV_EVENT_EOF event is trigger!";
410     // Notify about disconnection
411     if (srv->client_disconnection_) {
412       srv->client_disconnection_(*srv, *conn);
413     }
414     // Free connection structures
415     srv->RemoveConnection(conn->GetFd());
416   } else if (events & BEV_EVENT_ERROR) {
417     MS_LOG(WARNING) << "BEV_EVENT_ERROR event is trigger!";
418     if (PSContext::instance()->enable_ssl()) {
419       uint64_t err = bufferevent_get_openssl_error(bev);
420       MS_LOG(WARNING) << "The error number is:" << err;
421 
422       MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err)
423                       << ", the error lib:" << ERR_lib_error_string(err)
424                       << ", the error func:" << ERR_func_error_string(err);
425     }
426     // Notify about disconnection
427     if (srv->client_disconnection_) {
428       srv->client_disconnection_(*srv, *conn);
429     }
430     // Free connection structures
431     srv->RemoveConnection(conn->GetFd());
432   } else {
433     MS_LOG(WARNING) << "Unhandled event:" << events;
434   }
435 }
436 
SetTcpNoDelay(const evutil_socket_t & fd)437 void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
438   const int one = 1;
439   int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
440   if (ret < 0) {
441     MS_LOG(EXCEPTION) << "Set socket no delay failed!";
442   }
443 }
444 
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<CommMessage> & message)445 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
446   MS_EXCEPTION_IF_NULL(conn);
447   MS_EXCEPTION_IF_NULL(message);
448   return conn->SendMessage(message);
449 }
450 
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)451 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
452                             const Protos &protos, const void *data, size_t size) {
453   MS_EXCEPTION_IF_NULL(conn);
454   MS_EXCEPTION_IF_NULL(meta);
455   MS_EXCEPTION_IF_NULL(data);
456   return conn->SendMessage(meta, protos, data, size);
457 }
458 
SendMessage(const std::shared_ptr<CommMessage> & message)459 void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
460   MS_EXCEPTION_IF_NULL(message);
461   std::lock_guard<std::mutex> lock(connection_mutex_);
462 
463   for (auto it = connections_.begin(); it != connections_.end(); ++it) {
464     SendMessage(it->second, message);
465   }
466 }
467 
BoundPort() const468 uint16_t TcpServer::BoundPort() const { return server_port_; }
469 
BoundIp() const470 std::string TcpServer::BoundIp() const { return server_address_; }
471 
ConnectionNum() const472 int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
473 
Connections() const474 const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
475 
SetMessageCallback(const OnServerReceiveMessage & cb)476 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
477 }  // namespace core
478 }  // namespace ps
479 }  // namespace mindspore
480