• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/tcp_server.h"
18 
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <event2/listener.h>
25 #include <event2/util.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #include <csignal>
30 #include <utility>
31 
32 namespace mindspore {
33 namespace ps {
34 namespace core {
~TcpConnection()35 TcpConnection::~TcpConnection() { bufferevent_free(buffer_event_); }
InitConnection(const messageReceive & callback)36 void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
37 
OnReadHandler(const void * buffer,size_t num)38 void TcpConnection::OnReadHandler(const void *buffer, size_t num) {
39   MS_EXCEPTION_IF_NULL(buffer);
40   tcp_message_handler_.ReceiveMessage(buffer, num);
41 }
42 
SendMessage(const void * buffer,size_t num) const43 void TcpConnection::SendMessage(const void *buffer, size_t num) const {
44   MS_EXCEPTION_IF_NULL(buffer);
45   MS_EXCEPTION_IF_NULL(buffer_event_);
46   bufferevent_lock(buffer_event_);
47   if (bufferevent_write(buffer_event_, buffer, num) == -1) {
48     MS_LOG(ERROR) << "Write message to buffer event failed!";
49   }
50   bufferevent_unlock(buffer_event_);
51 }
52 
GetServer() const53 const TcpServer *TcpConnection::GetServer() const { return server_; }
54 
GetFd() const55 const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
56 
set_callback(const Callback & callback)57 void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
58 
SendMessage(const std::shared_ptr<CommMessage> & message) const59 bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
60   MS_EXCEPTION_IF_NULL(buffer_event_);
61   MS_EXCEPTION_IF_NULL(message);
62   bufferevent_lock(buffer_event_);
63   bool res = true;
64   size_t buf_size = message->ByteSizeLong();
65   if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
66     MS_LOG(ERROR) << "Event buffer add header failed!";
67     res = false;
68   }
69   if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
70     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
71     res = false;
72   }
73   bufferevent_unlock(buffer_event_);
74   return res;
75 }
76 
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size) const77 bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
78                                 size_t size) const {
79   MS_EXCEPTION_IF_NULL(buffer_event_);
80   MS_EXCEPTION_IF_NULL(meta);
81   MS_EXCEPTION_IF_NULL(data);
82   bufferevent_lock(buffer_event_);
83   bool res = true;
84   MessageHeader header;
85   header.message_proto_ = protos;
86   header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
87   header.message_length_ = size + header.message_meta_length_;
88 
89   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
90     MS_LOG(ERROR) << "Event buffer add header failed!";
91     res = false;
92   }
93   if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
94     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
95     res = false;
96   }
97   if (bufferevent_write(buffer_event_, data, size) == -1) {
98     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
99     res = false;
100   }
101   int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
102   if (result < 0) {
103     bufferevent_unlock(buffer_event_);
104     MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
105   }
106   bufferevent_unlock(buffer_event_);
107   return res;
108 }
109 
TcpServer(const std::string & address,std::uint16_t port,Configuration * const config)110 TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config)
111     : base_(nullptr),
112       signal_event_(nullptr),
113       listener_(nullptr),
114       server_address_(std::move(address)),
115       server_port_(port),
116       is_stop_(true),
117       config_(config),
118       max_connection_(0) {}
119 
~TcpServer()120 TcpServer::~TcpServer() {
121   if (signal_event_ != nullptr) {
122     event_free(signal_event_);
123     signal_event_ = nullptr;
124   }
125 
126   if (listener_ != nullptr) {
127     evconnlistener_free(listener_);
128     listener_ = nullptr;
129   }
130 
131   if (base_ != nullptr) {
132     event_base_free(base_);
133     base_ = nullptr;
134   }
135 }
136 
SetServerCallback(const OnConnected & client_conn,const OnDisconnected & client_disconn,const OnAccepted & client_accept)137 void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
138                                   const OnAccepted &client_accept) {
139   this->client_connection_ = client_conn;
140   this->client_disconnection_ = client_disconn;
141   this->client_accept_ = client_accept;
142 }
143 
set_timer_once_callback(const OnTimerOnce & timer)144 void TcpServer::set_timer_once_callback(const OnTimerOnce &timer) { on_timer_once_callback_ = timer; }
145 
set_timer_callback(const OnTimer & timer)146 void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
147 
Init()148 void TcpServer::Init() {
149   int result = evthread_use_pthreads();
150   if (result != 0) {
151     MS_LOG(EXCEPTION) << "Use event pthread failed!";
152   }
153 
154   is_stop_ = false;
155   base_ = event_base_new();
156   MS_EXCEPTION_IF_NULL(base_);
157   if (!CommUtil::CheckIp(server_address_)) {
158     MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
159   }
160   MS_EXCEPTION_IF_NULL(config_);
161   max_connection_ = kConnectionNumDefault;
162   if (config_->Exists(kConnectionNum)) {
163     max_connection_ = config_->GetInt(kConnectionNum, 0);
164   }
165   MS_LOG(INFO) << "The max connection is:" << max_connection_;
166 
167   struct sockaddr_in sin {};
168   if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
169     MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
170   }
171   sin.sin_family = AF_INET;
172   sin.sin_port = htons(server_port_);
173   sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
174 
175   listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
176                                       LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
177                                       reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
178 
179   MS_EXCEPTION_IF_NULL(listener_);
180 
181   if (server_port_ == 0) {
182     struct sockaddr_in sin_bound {};
183     if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) {
184       MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
185     }
186     socklen_t addr_len = sizeof(struct sockaddr_in);
187     if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) {
188       MS_LOG(EXCEPTION) << "Get sock name failed!";
189     }
190     server_port_ = htons(sin_bound.sin_port);
191   }
192 
193   signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
194   MS_EXCEPTION_IF_NULL(signal_event_);
195   if (event_add(signal_event_, nullptr) < 0) {
196     MS_LOG(EXCEPTION) << "Cannot create signal event.";
197   }
198 }
199 
Start()200 void TcpServer::Start() {
201   MS_LOG(INFO) << "Start tcp server!";
202   MS_EXCEPTION_IF_NULL(base_);
203   int ret = event_base_dispatch(base_);
204   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
205   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
206     << "Event base dispatch failed with no events pending or active!";
207   MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
208   MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
209 }
210 
StartWithNoBlock()211 void TcpServer::StartWithNoBlock() {
212   std::lock_guard<std::mutex> lock(connection_mutex_);
213   MS_LOG(INFO) << "Start tcp server with no block!";
214   MS_EXCEPTION_IF_NULL(base_);
215   int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
216   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
217   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
218   MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
219   MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
220 }
221 
Stop()222 void TcpServer::Stop() {
223   MS_EXCEPTION_IF_NULL(base_);
224   std::lock_guard<std::mutex> lock(connection_mutex_);
225   MS_LOG(INFO) << "Stop tcp server!";
226   if (event_base_got_break(base_)) {
227     MS_LOG(DEBUG) << "The event base has stopped!";
228     is_stop_ = true;
229     return;
230   }
231   if (!is_stop_.load()) {
232     is_stop_ = true;
233     int ret = event_base_loopbreak(base_);
234     if (ret != 0) {
235       MS_LOG(ERROR) << "Event base loop break failed!";
236     }
237   }
238 }
239 
SendToAllClients(const char * data,size_t len)240 void TcpServer::SendToAllClients(const char *data, size_t len) {
241   MS_EXCEPTION_IF_NULL(data);
242   std::lock_guard<std::mutex> lock(connection_mutex_);
243   for (auto it = connections_.begin(); it != connections_.end(); ++it) {
244     it->second->SendMessage(data, len);
245   }
246 }
247 
AddConnection(const evutil_socket_t & fd,std::shared_ptr<TcpConnection> connection)248 void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
249   MS_EXCEPTION_IF_NULL(connection);
250   std::lock_guard<std::mutex> lock(connection_mutex_);
251   connections_.insert(std::make_pair(fd, connection));
252 }
253 
RemoveConnection(const evutil_socket_t & fd)254 void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
255   std::lock_guard<std::mutex> lock(connection_mutex_);
256   connections_.erase(fd);
257 }
258 
GetConnectionByFd(const evutil_socket_t & fd)259 std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
260 
ListenerCallback(struct evconnlistener *,evutil_socket_t fd,struct sockaddr * sockaddr,int,void * data)261 void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
262                                  void *data) {
263   auto server = reinterpret_cast<class TcpServer *>(data);
264   MS_EXCEPTION_IF_NULL(server);
265   auto base = reinterpret_cast<struct event_base *>(server->base_);
266   MS_EXCEPTION_IF_NULL(base);
267   MS_EXCEPTION_IF_NULL(sockaddr);
268 
269   if (server->ConnectionNum() >= server->max_connection_) {
270     MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to "
271                     << server->max_connection_;
272     return;
273   }
274 
275   struct bufferevent *bev = nullptr;
276 
277   if (!PSContext::instance()->enable_ssl()) {
278     MS_LOG(INFO) << "SSL is disable.";
279     bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
280   } else {
281     MS_LOG(INFO) << "Enable ssl support.";
282     SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());
283     MS_EXCEPTION_IF_NULL(ssl);
284     bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING,
285                                          BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
286   }
287   if (bev == nullptr) {
288     MS_LOG(ERROR) << "Error constructing buffer event!";
289     int ret = event_base_loopbreak(base);
290     if (ret != 0) {
291       MS_LOG(EXCEPTION) << "event base loop break failed!";
292     }
293     return;
294   }
295 
296   std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
297   MS_EXCEPTION_IF_NULL(conn);
298   SetTcpNoDelay(fd);
299   server->AddConnection(fd, conn);
300   conn->InitConnection(
301     [=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
302       OnServerReceiveMessage on_server_receive = server->GetServerReceive();
303       if (on_server_receive) {
304         on_server_receive(conn, meta, protos, data, size);
305       }
306     });
307   bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
308                     reinterpret_cast<void *>(conn.get()));
309   if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
310     MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
311   }
312 }
313 
onCreateConnection(struct bufferevent * bev,const evutil_socket_t & fd)314 std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
315   MS_EXCEPTION_IF_NULL(bev);
316   std::shared_ptr<TcpConnection> conn = nullptr;
317   if (client_accept_) {
318     conn = (client_accept_(*this));
319   } else {
320     conn = std::make_shared<TcpConnection>(bev, fd, this);
321   }
322 
323   return conn;
324 }
325 
GetServerReceive() const326 OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
327 
SignalCallback(evutil_socket_t,std::int16_t,void * data)328 void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
329   MS_EXCEPTION_IF_NULL(data);
330   auto server = reinterpret_cast<class TcpServer *>(data);
331   struct event_base *base = server->base_;
332   MS_EXCEPTION_IF_NULL(base);
333   struct timeval delay = {0, 0};
334   MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
335   if (event_base_loopexit(base, &delay) == -1) {
336     MS_LOG(ERROR) << "Event base loop exit failed.";
337   }
338 }
339 
ReadCallback(struct bufferevent * bev,void * connection)340 void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
341   MS_EXCEPTION_IF_NULL(bev);
342   MS_EXCEPTION_IF_NULL(connection);
343 
344   auto conn = static_cast<class TcpConnection *>(connection);
345   struct evbuffer *buf = bufferevent_get_input(bev);
346   MS_EXCEPTION_IF_NULL(buf);
347   char read_buffer[kMessageChunkLength];
348   while (EVBUFFER_LENGTH(buf) > 0) {
349     int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
350     if (read == -1) {
351       MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
352     }
353     conn->OnReadHandler(read_buffer, IntToSize(read));
354   }
355 }
356 
EventCallback(struct bufferevent * bev,std::int16_t events,void * data)357 void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
358   MS_EXCEPTION_IF_NULL(bev);
359   MS_EXCEPTION_IF_NULL(data);
360   struct evbuffer *output = bufferevent_get_output(bev);
361   MS_EXCEPTION_IF_NULL(output);
362   auto conn = static_cast<class TcpConnection *>(data);
363   auto srv = const_cast<TcpServer *>(conn->GetServer());
364   MS_EXCEPTION_IF_NULL(srv);
365 
366   if (events & BEV_EVENT_EOF) {
367     MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!";
368     // Notify about disconnection
369     if (srv->client_disconnection_) {
370       srv->client_disconnection_(*srv, *conn);
371     }
372     // Free connection structures
373     srv->RemoveConnection(conn->GetFd());
374   } else if (events & BEV_EVENT_ERROR) {
375     MS_LOG(WARNING) << "Connect to server error.";
376     if (PSContext::instance()->enable_ssl()) {
377       uint64_t err = bufferevent_get_openssl_error(bev);
378       MS_LOG(WARNING) << "The error number is:" << err;
379 
380       MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err)
381                       << ", the error lib:" << ERR_lib_error_string(err)
382                       << ", the error func:" << ERR_func_error_string(err);
383     }
384     // Free connection structures
385     srv->RemoveConnection(conn->GetFd());
386 
387     // Notify about disconnection
388     if (srv->client_disconnection_) {
389       srv->client_disconnection_(*srv, *conn);
390     }
391   } else {
392     MS_LOG(WARNING) << "Unhandled event:" << events;
393   }
394 }
395 
TimerCallback(evutil_socket_t,int16_t,void * arg)396 void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) {
397   MS_EXCEPTION_IF_NULL(arg);
398   auto tcp_server = reinterpret_cast<TcpServer *>(arg);
399   if (tcp_server->on_timer_callback_) {
400     tcp_server->on_timer_callback_();
401   }
402 }
403 
TimerOnceCallback(evutil_socket_t,int16_t,void * arg)404 void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
405   MS_EXCEPTION_IF_NULL(arg);
406   auto tcp_server = reinterpret_cast<TcpServer *>(arg);
407   if (tcp_server->on_timer_once_callback_) {
408     tcp_server->on_timer_once_callback_(*tcp_server);
409   }
410 }
411 
SetTcpNoDelay(const evutil_socket_t & fd)412 void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
413   const int one = 1;
414   int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
415   if (ret < 0) {
416     MS_LOG(EXCEPTION) << "Set socket no delay failed!";
417   }
418 }
419 
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<CommMessage> & message)420 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
421   MS_EXCEPTION_IF_NULL(conn);
422   MS_EXCEPTION_IF_NULL(message);
423   return conn->SendMessage(message);
424 }
425 
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)426 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
427                             const Protos &protos, const void *data, size_t size) {
428   MS_EXCEPTION_IF_NULL(conn);
429   MS_EXCEPTION_IF_NULL(meta);
430   MS_EXCEPTION_IF_NULL(data);
431   return conn->SendMessage(meta, protos, data, size);
432 }
433 
SendMessage(const std::shared_ptr<CommMessage> & message)434 void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
435   MS_EXCEPTION_IF_NULL(message);
436   std::lock_guard<std::mutex> lock(connection_mutex_);
437 
438   for (auto it = connections_.begin(); it != connections_.end(); ++it) {
439     SendMessage(it->second, message);
440   }
441 }
442 
BoundPort() const443 uint16_t TcpServer::BoundPort() const { return server_port_; }
444 
BoundIp() const445 std::string TcpServer::BoundIp() const { return server_address_; }
446 
ConnectionNum() const447 int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
448 
Connections() const449 const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
450 
SetMessageCallback(const OnServerReceiveMessage & cb)451 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
452 }  // namespace core
453 }  // namespace ps
454 }  // namespace mindspore
455