1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/tcp_server.h"
18
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <event2/listener.h>
25 #include <event2/util.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #include <csignal>
30 #include <utility>
31
32 namespace mindspore {
33 namespace ps {
34 namespace core {
~TcpConnection()35 TcpConnection::~TcpConnection() {
36 MS_LOG(WARNING) << "TcpConnection is destructed! fd is " << fd_;
37 bufferevent_free(buffer_event_);
38 }
InitConnection(const messageReceive & callback)39 void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
40
OnReadHandler(const void * buffer,size_t num)41 void TcpConnection::OnReadHandler(const void *buffer, size_t num) {
42 MS_EXCEPTION_IF_NULL(buffer);
43 tcp_message_handler_.ReceiveMessage(buffer, num);
44 }
45
SendMessage(const void * buffer,size_t num) const46 void TcpConnection::SendMessage(const void *buffer, size_t num) const {
47 MS_EXCEPTION_IF_NULL(buffer);
48 MS_EXCEPTION_IF_NULL(buffer_event_);
49 bufferevent_lock(buffer_event_);
50 if (bufferevent_write(buffer_event_, buffer, num) == -1) {
51 MS_LOG(ERROR) << "Write message to buffer event failed!";
52 }
53 bufferevent_unlock(buffer_event_);
54 }
55
GetServer() const56 const TcpServer *TcpConnection::GetServer() const { return server_; }
57
GetFd() const58 const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
59
set_callback(const Callback & callback)60 void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
61
SendMessage(const std::shared_ptr<CommMessage> & message) const62 bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
63 MS_EXCEPTION_IF_NULL(buffer_event_);
64 MS_EXCEPTION_IF_NULL(message);
65 bufferevent_lock(buffer_event_);
66 bool res = true;
67 size_t buf_size = message->ByteSizeLong();
68 if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
69 MS_LOG(ERROR) << "Event buffer add header failed!";
70 res = false;
71 }
72 if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
73 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
74 res = false;
75 }
76 bufferevent_unlock(buffer_event_);
77 return res;
78 }
79
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size) const80 bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
81 size_t size) const {
82 MS_EXCEPTION_IF_NULL(buffer_event_);
83 MS_EXCEPTION_IF_NULL(meta);
84 MS_EXCEPTION_IF_NULL(data);
85 bufferevent_lock(buffer_event_);
86 bool res = true;
87 MessageHeader header;
88 header.message_proto_ = protos;
89 header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
90 header.message_length_ = size + header.message_meta_length_;
91
92 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
93 MS_LOG(ERROR) << "Event buffer add header failed!";
94 res = false;
95 }
96 if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
97 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
98 res = false;
99 }
100 if (bufferevent_write(buffer_event_, data, size) == -1) {
101 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
102 res = false;
103 }
104 int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
105 if (result < 0) {
106 bufferevent_unlock(buffer_event_);
107 MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
108 }
109 bufferevent_unlock(buffer_event_);
110 return res;
111 }
112
TcpServer(const std::string & address,std::uint16_t port,Configuration * const config,const std::pair<uint32_t,uint32_t> & port_range)113 TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config,
114 const std::pair<uint32_t, uint32_t> &port_range)
115 : base_(nullptr),
116 signal_event_(nullptr),
117 listener_(nullptr),
118 server_address_(std::move(address)),
119 server_port_(port),
120 is_stop_(true),
121 config_(config),
122 max_connection_(0),
123 port_range_(port_range) {}
124
~TcpServer()125 TcpServer::~TcpServer() {
126 if (signal_event_ != nullptr) {
127 event_free(signal_event_);
128 signal_event_ = nullptr;
129 }
130
131 if (listener_ != nullptr) {
132 evconnlistener_free(listener_);
133 listener_ = nullptr;
134 }
135
136 if (base_ != nullptr) {
137 event_base_free(base_);
138 base_ = nullptr;
139 }
140 }
141
SetServerCallback(const OnConnected & client_conn,const OnDisconnected & client_disconn,const OnAccepted & client_accept)142 void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
143 const OnAccepted &client_accept) {
144 this->client_connection_ = client_conn;
145 this->client_disconnection_ = client_disconn;
146 this->client_accept_ = client_accept;
147 }
148
Init()149 void TcpServer::Init() {
150 if (PSContext::instance()->enable_ssl()) {
151 MS_LOG(INFO) << "Init ssl.";
152 SSLWrapper::GetInstance().InitSSL();
153 }
154 int result = evthread_use_pthreads();
155 if (result != 0) {
156 MS_LOG(EXCEPTION) << "Use event pthread failed!";
157 }
158
159 is_stop_ = false;
160 base_ = event_base_new();
161 MS_EXCEPTION_IF_NULL(base_);
162 if (!CommUtil::CheckIp(server_address_)) {
163 MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
164 }
165 MS_EXCEPTION_IF_NULL(config_);
166 max_connection_ = kConnectionNumDefault;
167 if (config_->Exists(kConnectionNum)) {
168 max_connection_ = config_->GetInt(kConnectionNum, 0);
169 }
170 MS_LOG(INFO) << "The max connection is:" << max_connection_;
171
172 struct sockaddr_in sin {};
173 if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
174 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
175 }
176 sin.sin_family = AF_INET;
177 sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
178 // We do not use server_port_ because it's always 0.
179 MS_LOG(INFO) << "Initialize TcpServer with port range " << port_range_.first << " to " << port_range_.second;
180 uint16_t current_port = static_cast<uint16_t>(port_range_.first);
181 do {
182 sin.sin_port = htons(current_port);
183 listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
184 LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
185 reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
186 if (listener_ == nullptr) {
187 current_port++;
188 MS_LOG(WARNING) << "The port " << current_port << " is already in use. So increase port to: " << current_port;
189 if (current_port > port_range_.second) {
190 MS_LOG(EXCEPTION) << "Port range " << port_range_.first << " to " << port_range_.second
191 << " are all in use already. You can run 'netstat -anp|grep <port number>' command to check "
192 "which process occupies the port.";
193 }
194 }
195 } while (listener_ == nullptr);
196
197 if (server_port_ == 0) {
198 struct sockaddr_in sin_bound {};
199 if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) {
200 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
201 }
202 socklen_t addr_len = sizeof(struct sockaddr_in);
203 if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) {
204 MS_LOG(EXCEPTION) << "Get sock name failed!";
205 }
206 server_port_ = htons(sin_bound.sin_port);
207 }
208
209 signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
210 MS_EXCEPTION_IF_NULL(signal_event_);
211 if (event_add(signal_event_, nullptr) < 0) {
212 MS_LOG(EXCEPTION) << "Cannot create signal event.";
213 }
214 }
215
Start()216 void TcpServer::Start() {
217 MS_LOG(INFO) << "Start tcp server!";
218 MS_EXCEPTION_IF_NULL(base_);
219 int ret = event_base_dispatch(base_);
220 MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base dispatch success!";
221 MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
222 << "Event base dispatch failed with no events pending or active!";
223 MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr)
224 << "Event base dispatch failed with error occurred!";
225 MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr)
226 << "Event base dispatch with unexpected error code!";
227 }
228
Stop()229 void TcpServer::Stop() {
230 MS_ERROR_IF_NULL_WO_RET_VAL(base_);
231 std::lock_guard<std::mutex> lock(connection_mutex_);
232 MS_LOG(INFO) << "Stop tcp server!";
233 if (event_base_got_break(base_)) {
234 MS_LOG(DEBUG) << "The event base has already been stopped!";
235 is_stop_ = true;
236 return;
237 }
238 if (!is_stop_.load()) {
239 is_stop_ = true;
240 int ret = event_base_loopbreak(base_);
241 if (ret != 0) {
242 MS_LOG(ERROR) << "Event base loop break failed!";
243 }
244 }
245 }
246
SendToAllClients(const char * data,size_t len)247 void TcpServer::SendToAllClients(const char *data, size_t len) {
248 MS_EXCEPTION_IF_NULL(data);
249 std::lock_guard<std::mutex> lock(connection_mutex_);
250 for (auto it = connections_.begin(); it != connections_.end(); ++it) {
251 it->second->SendMessage(data, len);
252 }
253 }
254
AddConnection(const evutil_socket_t & fd,std::shared_ptr<TcpConnection> connection)255 void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
256 MS_EXCEPTION_IF_NULL(connection);
257 std::lock_guard<std::mutex> lock(connection_mutex_);
258 (void)connections_.insert(std::make_pair(fd, connection));
259 }
260
RemoveConnection(const evutil_socket_t & fd)261 void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
262 std::lock_guard<std::mutex> lock(connection_mutex_);
263 MS_LOG(INFO) << "Remove connection fd: " << fd;
264 (void)connections_.erase(fd);
265 }
266
GetConnectionByFd(const evutil_socket_t & fd)267 std::shared_ptr<TcpConnection> &TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
268
ListenerCallback(struct evconnlistener *,evutil_socket_t fd,struct sockaddr * sockaddr,int,void * const data)269 void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
270 void *const data) {
271 try {
272 ListenerCallbackInner(fd, sockaddr, data);
273 } catch (const std::exception &e) {
274 MS_LOG(ERROR) << "Catch exception: " << e.what();
275 }
276 }
277
ListenerCallbackInner(evutil_socket_t fd,struct sockaddr * sockaddr,void * const data)278 void TcpServer::ListenerCallbackInner(evutil_socket_t fd, struct sockaddr *sockaddr, void *const data) {
279 auto server = reinterpret_cast<class TcpServer *>(data);
280 MS_EXCEPTION_IF_NULL(server);
281 auto base = reinterpret_cast<struct event_base *>(server->base_);
282 MS_EXCEPTION_IF_NULL(base);
283 MS_EXCEPTION_IF_NULL(sockaddr);
284
285 if (server->ConnectionNum() >= server->max_connection_) {
286 MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to "
287 << server->max_connection_;
288 return;
289 }
290
291 struct bufferevent *bev = nullptr;
292
293 if (!PSContext::instance()->enable_ssl()) {
294 MS_LOG(INFO) << "SSL is disable.";
295 bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
296 } else {
297 MS_LOG(INFO) << "Enable ssl support.";
298 SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());
299 MS_EXCEPTION_IF_NULL(ssl);
300 bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING,
301 BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
302 }
303 if (bev == nullptr) {
304 MS_LOG(ERROR) << "Error constructing buffer event!";
305 int ret = event_base_loopbreak(base);
306 if (ret != 0) {
307 MS_LOG(EXCEPTION) << "event base loop break failed!";
308 }
309 return;
310 }
311
312 std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
313 MS_EXCEPTION_IF_NULL(conn);
314 SetTcpNoDelay(fd);
315 server->AddConnection(fd, conn);
316 conn->InitConnection(
317 [=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
318 OnServerReceiveMessage on_server_receive = server->GetServerReceive();
319 if (on_server_receive) {
320 on_server_receive(conn, meta, protos, data, size);
321 }
322 });
323
324 bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
325 reinterpret_cast<void *>(conn.get()));
326 MS_LOG(INFO) << "A client is connected, fd is " << fd;
327 if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
328 MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
329 }
330 }
331
onCreateConnection(struct bufferevent * bev,const evutil_socket_t & fd)332 std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
333 MS_EXCEPTION_IF_NULL(bev);
334 std::shared_ptr<TcpConnection> conn = nullptr;
335 if (client_accept_) {
336 conn = (client_accept_(*this));
337 } else {
338 conn = std::make_shared<TcpConnection>(bev, fd, this);
339 }
340
341 return conn;
342 }
343
GetServerReceive() const344 OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
345
SignalCallback(evutil_socket_t,std::int16_t,void * const data)346 void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *const data) {
347 try {
348 SignalCallbackInner(data);
349 } catch (const std::exception &e) {
350 MS_LOG(ERROR) << "Catch exception: " << e.what();
351 }
352 }
353
SignalCallbackInner(void * const data)354 void TcpServer::SignalCallbackInner(void *const data) {
355 MS_EXCEPTION_IF_NULL(data);
356 auto server = reinterpret_cast<class TcpServer *>(data);
357 struct event_base *base = server->base_;
358 MS_EXCEPTION_IF_NULL(base);
359 struct timeval delay = {0, 0};
360 MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
361 if (event_base_loopexit(base, &delay) == -1) {
362 MS_LOG(ERROR) << "Event base loop exit failed.";
363 }
364 }
365
ReadCallback(struct bufferevent * bev,void * const connection)366 void TcpServer::ReadCallback(struct bufferevent *bev, void *const connection) {
367 try {
368 ReadCallbackInner(bev, connection);
369 } catch (const std::exception &e) {
370 MS_LOG(ERROR) << "Catch exception: " << e.what();
371 }
372 }
373
ReadCallbackInner(struct bufferevent * bev,void * const connection)374 void TcpServer::ReadCallbackInner(struct bufferevent *bev, void *const connection) {
375 MS_EXCEPTION_IF_NULL(bev);
376 MS_EXCEPTION_IF_NULL(connection);
377
378 auto conn = static_cast<class TcpConnection *>(connection);
379 struct evbuffer *buf = bufferevent_get_input(bev);
380 MS_EXCEPTION_IF_NULL(buf);
381 char read_buffer[kMessageChunkLength];
382 while (EVBUFFER_LENGTH(buf) > 0) {
383 int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
384 if (read == -1) {
385 MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
386 }
387 conn->OnReadHandler(read_buffer, IntToSize(read));
388 }
389 }
390
EventCallback(struct bufferevent * bev,std::int16_t events,void * const data)391 void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *const data) {
392 try {
393 EventCallbackInner(bev, events, data);
394 } catch (const std::exception &e) {
395 MS_LOG(ERROR) << "Catch exception: " << e.what();
396 }
397 }
398
EventCallbackInner(struct bufferevent * bev,std::int16_t events,void * const data)399 void TcpServer::EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *const data) {
400 MS_EXCEPTION_IF_NULL(bev);
401 MS_EXCEPTION_IF_NULL(data);
402 struct evbuffer *output = bufferevent_get_output(bev);
403 MS_EXCEPTION_IF_NULL(output);
404 auto conn = static_cast<class TcpConnection *>(data);
405 auto srv = const_cast<TcpServer *>(conn->GetServer());
406 MS_EXCEPTION_IF_NULL(srv);
407
408 if (events & BEV_EVENT_EOF) {
409 MS_LOG(INFO) << "BEV_EVENT_EOF event is trigger!";
410 // Notify about disconnection
411 if (srv->client_disconnection_) {
412 srv->client_disconnection_(*srv, *conn);
413 }
414 // Free connection structures
415 srv->RemoveConnection(conn->GetFd());
416 } else if (events & BEV_EVENT_ERROR) {
417 MS_LOG(WARNING) << "BEV_EVENT_ERROR event is trigger!";
418 if (PSContext::instance()->enable_ssl()) {
419 uint64_t err = bufferevent_get_openssl_error(bev);
420 MS_LOG(WARNING) << "The error number is:" << err;
421
422 MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err)
423 << ", the error lib:" << ERR_lib_error_string(err)
424 << ", the error func:" << ERR_func_error_string(err);
425 }
426 // Notify about disconnection
427 if (srv->client_disconnection_) {
428 srv->client_disconnection_(*srv, *conn);
429 }
430 // Free connection structures
431 srv->RemoveConnection(conn->GetFd());
432 } else {
433 MS_LOG(WARNING) << "Unhandled event:" << events;
434 }
435 }
436
SetTcpNoDelay(const evutil_socket_t & fd)437 void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
438 const int one = 1;
439 int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
440 if (ret < 0) {
441 MS_LOG(EXCEPTION) << "Set socket no delay failed!";
442 }
443 }
444
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<CommMessage> & message)445 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
446 MS_EXCEPTION_IF_NULL(conn);
447 MS_EXCEPTION_IF_NULL(message);
448 return conn->SendMessage(message);
449 }
450
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)451 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
452 const Protos &protos, const void *data, size_t size) {
453 MS_EXCEPTION_IF_NULL(conn);
454 MS_EXCEPTION_IF_NULL(meta);
455 MS_EXCEPTION_IF_NULL(data);
456 return conn->SendMessage(meta, protos, data, size);
457 }
458
SendMessage(const std::shared_ptr<CommMessage> & message)459 void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
460 MS_EXCEPTION_IF_NULL(message);
461 std::lock_guard<std::mutex> lock(connection_mutex_);
462
463 for (auto it = connections_.begin(); it != connections_.end(); ++it) {
464 SendMessage(it->second, message);
465 }
466 }
467
BoundPort() const468 uint16_t TcpServer::BoundPort() const { return server_port_; }
469
BoundIp() const470 std::string TcpServer::BoundIp() const { return server_address_; }
471
ConnectionNum() const472 int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
473
Connections() const474 const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
475
SetMessageCallback(const OnServerReceiveMessage & cb)476 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
477 } // namespace core
478 } // namespace ps
479 } // namespace mindspore
480