1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/tcp_server.h"
18
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <event2/listener.h>
25 #include <event2/util.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #include <csignal>
30 #include <utility>
31
32 namespace mindspore {
33 namespace ps {
34 namespace core {
~TcpConnection()35 TcpConnection::~TcpConnection() { bufferevent_free(buffer_event_); }
InitConnection(const messageReceive & callback)36 void TcpConnection::InitConnection(const messageReceive &callback) { tcp_message_handler_.SetCallback(callback); }
37
OnReadHandler(const void * buffer,size_t num)38 void TcpConnection::OnReadHandler(const void *buffer, size_t num) {
39 MS_EXCEPTION_IF_NULL(buffer);
40 tcp_message_handler_.ReceiveMessage(buffer, num);
41 }
42
SendMessage(const void * buffer,size_t num) const43 void TcpConnection::SendMessage(const void *buffer, size_t num) const {
44 MS_EXCEPTION_IF_NULL(buffer);
45 MS_EXCEPTION_IF_NULL(buffer_event_);
46 bufferevent_lock(buffer_event_);
47 if (bufferevent_write(buffer_event_, buffer, num) == -1) {
48 MS_LOG(ERROR) << "Write message to buffer event failed!";
49 }
50 bufferevent_unlock(buffer_event_);
51 }
52
GetServer() const53 const TcpServer *TcpConnection::GetServer() const { return server_; }
54
GetFd() const55 const evutil_socket_t &TcpConnection::GetFd() const { return fd_; }
56
set_callback(const Callback & callback)57 void TcpConnection::set_callback(const Callback &callback) { callback_ = callback; }
58
SendMessage(const std::shared_ptr<CommMessage> & message) const59 bool TcpConnection::SendMessage(const std::shared_ptr<CommMessage> &message) const {
60 MS_EXCEPTION_IF_NULL(buffer_event_);
61 MS_EXCEPTION_IF_NULL(message);
62 bufferevent_lock(buffer_event_);
63 bool res = true;
64 size_t buf_size = message->ByteSizeLong();
65 if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
66 MS_LOG(ERROR) << "Event buffer add header failed!";
67 res = false;
68 }
69 if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
70 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
71 res = false;
72 }
73 bufferevent_unlock(buffer_event_);
74 return res;
75 }
76
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size) const77 bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
78 size_t size) const {
79 MS_EXCEPTION_IF_NULL(buffer_event_);
80 MS_EXCEPTION_IF_NULL(meta);
81 MS_EXCEPTION_IF_NULL(data);
82 bufferevent_lock(buffer_event_);
83 bool res = true;
84 MessageHeader header;
85 header.message_proto_ = protos;
86 header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
87 header.message_length_ = size + header.message_meta_length_;
88
89 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
90 MS_LOG(ERROR) << "Event buffer add header failed!";
91 res = false;
92 }
93 if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
94 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
95 res = false;
96 }
97 if (bufferevent_write(buffer_event_, data, size) == -1) {
98 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
99 res = false;
100 }
101 int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
102 if (result < 0) {
103 bufferevent_unlock(buffer_event_);
104 MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
105 }
106 bufferevent_unlock(buffer_event_);
107 return res;
108 }
109
TcpServer(const std::string & address,std::uint16_t port,Configuration * const config)110 TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *const config)
111 : base_(nullptr),
112 signal_event_(nullptr),
113 listener_(nullptr),
114 server_address_(std::move(address)),
115 server_port_(port),
116 is_stop_(true),
117 config_(config),
118 max_connection_(0) {}
119
~TcpServer()120 TcpServer::~TcpServer() {
121 if (signal_event_ != nullptr) {
122 event_free(signal_event_);
123 signal_event_ = nullptr;
124 }
125
126 if (listener_ != nullptr) {
127 evconnlistener_free(listener_);
128 listener_ = nullptr;
129 }
130
131 if (base_ != nullptr) {
132 event_base_free(base_);
133 base_ = nullptr;
134 }
135 }
136
SetServerCallback(const OnConnected & client_conn,const OnDisconnected & client_disconn,const OnAccepted & client_accept)137 void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
138 const OnAccepted &client_accept) {
139 this->client_connection_ = client_conn;
140 this->client_disconnection_ = client_disconn;
141 this->client_accept_ = client_accept;
142 }
143
set_timer_once_callback(const OnTimerOnce & timer)144 void TcpServer::set_timer_once_callback(const OnTimerOnce &timer) { on_timer_once_callback_ = timer; }
145
set_timer_callback(const OnTimer & timer)146 void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
147
Init()148 void TcpServer::Init() {
149 int result = evthread_use_pthreads();
150 if (result != 0) {
151 MS_LOG(EXCEPTION) << "Use event pthread failed!";
152 }
153
154 is_stop_ = false;
155 base_ = event_base_new();
156 MS_EXCEPTION_IF_NULL(base_);
157 if (!CommUtil::CheckIp(server_address_)) {
158 MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!";
159 }
160 MS_EXCEPTION_IF_NULL(config_);
161 max_connection_ = kConnectionNumDefault;
162 if (config_->Exists(kConnectionNum)) {
163 max_connection_ = config_->GetInt(kConnectionNum, 0);
164 }
165 MS_LOG(INFO) << "The max connection is:" << max_connection_;
166
167 struct sockaddr_in sin {};
168 if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
169 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
170 }
171 sin.sin_family = AF_INET;
172 sin.sin_port = htons(server_port_);
173 sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
174
175 listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
176 LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
177 reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
178
179 MS_EXCEPTION_IF_NULL(listener_);
180
181 if (server_port_ == 0) {
182 struct sockaddr_in sin_bound {};
183 if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) {
184 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
185 }
186 socklen_t addr_len = sizeof(struct sockaddr_in);
187 if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) {
188 MS_LOG(EXCEPTION) << "Get sock name failed!";
189 }
190 server_port_ = htons(sin_bound.sin_port);
191 }
192
193 signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
194 MS_EXCEPTION_IF_NULL(signal_event_);
195 if (event_add(signal_event_, nullptr) < 0) {
196 MS_LOG(EXCEPTION) << "Cannot create signal event.";
197 }
198 }
199
Start()200 void TcpServer::Start() {
201 MS_LOG(INFO) << "Start tcp server!";
202 MS_EXCEPTION_IF_NULL(base_);
203 int ret = event_base_dispatch(base_);
204 MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
205 MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
206 << "Event base dispatch failed with no events pending or active!";
207 MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
208 MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
209 }
210
StartWithNoBlock()211 void TcpServer::StartWithNoBlock() {
212 std::lock_guard<std::mutex> lock(connection_mutex_);
213 MS_LOG(INFO) << "Start tcp server with no block!";
214 MS_EXCEPTION_IF_NULL(base_);
215 int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
216 MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
217 MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
218 MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
219 MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
220 }
221
Stop()222 void TcpServer::Stop() {
223 MS_EXCEPTION_IF_NULL(base_);
224 std::lock_guard<std::mutex> lock(connection_mutex_);
225 MS_LOG(INFO) << "Stop tcp server!";
226 if (event_base_got_break(base_)) {
227 MS_LOG(DEBUG) << "The event base has stopped!";
228 is_stop_ = true;
229 return;
230 }
231 if (!is_stop_.load()) {
232 is_stop_ = true;
233 int ret = event_base_loopbreak(base_);
234 if (ret != 0) {
235 MS_LOG(ERROR) << "Event base loop break failed!";
236 }
237 }
238 }
239
SendToAllClients(const char * data,size_t len)240 void TcpServer::SendToAllClients(const char *data, size_t len) {
241 MS_EXCEPTION_IF_NULL(data);
242 std::lock_guard<std::mutex> lock(connection_mutex_);
243 for (auto it = connections_.begin(); it != connections_.end(); ++it) {
244 it->second->SendMessage(data, len);
245 }
246 }
247
AddConnection(const evutil_socket_t & fd,std::shared_ptr<TcpConnection> connection)248 void TcpServer::AddConnection(const evutil_socket_t &fd, std::shared_ptr<TcpConnection> connection) {
249 MS_EXCEPTION_IF_NULL(connection);
250 std::lock_guard<std::mutex> lock(connection_mutex_);
251 connections_.insert(std::make_pair(fd, connection));
252 }
253
RemoveConnection(const evutil_socket_t & fd)254 void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
255 std::lock_guard<std::mutex> lock(connection_mutex_);
256 connections_.erase(fd);
257 }
258
GetConnectionByFd(const evutil_socket_t & fd)259 std::shared_ptr<TcpConnection> TcpServer::GetConnectionByFd(const evutil_socket_t &fd) { return connections_[fd]; }
260
ListenerCallback(struct evconnlistener *,evutil_socket_t fd,struct sockaddr * sockaddr,int,void * data)261 void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int,
262 void *data) {
263 auto server = reinterpret_cast<class TcpServer *>(data);
264 MS_EXCEPTION_IF_NULL(server);
265 auto base = reinterpret_cast<struct event_base *>(server->base_);
266 MS_EXCEPTION_IF_NULL(base);
267 MS_EXCEPTION_IF_NULL(sockaddr);
268
269 if (server->ConnectionNum() >= server->max_connection_) {
270 MS_LOG(WARNING) << "The current connection num:" << server->ConnectionNum() << " is greater or equal to "
271 << server->max_connection_;
272 return;
273 }
274
275 struct bufferevent *bev = nullptr;
276
277 if (!PSContext::instance()->enable_ssl()) {
278 MS_LOG(INFO) << "SSL is disable.";
279 bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
280 } else {
281 MS_LOG(INFO) << "Enable ssl support.";
282 SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());
283 MS_EXCEPTION_IF_NULL(ssl);
284 bev = bufferevent_openssl_socket_new(base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING,
285 BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
286 }
287 if (bev == nullptr) {
288 MS_LOG(ERROR) << "Error constructing buffer event!";
289 int ret = event_base_loopbreak(base);
290 if (ret != 0) {
291 MS_LOG(EXCEPTION) << "event base loop break failed!";
292 }
293 return;
294 }
295
296 std::shared_ptr<TcpConnection> conn = server->onCreateConnection(bev, fd);
297 MS_EXCEPTION_IF_NULL(conn);
298 SetTcpNoDelay(fd);
299 server->AddConnection(fd, conn);
300 conn->InitConnection(
301 [=](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
302 OnServerReceiveMessage on_server_receive = server->GetServerReceive();
303 if (on_server_receive) {
304 on_server_receive(conn, meta, protos, data, size);
305 }
306 });
307 bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
308 reinterpret_cast<void *>(conn.get()));
309 if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
310 MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
311 }
312 }
313
onCreateConnection(struct bufferevent * bev,const evutil_socket_t & fd)314 std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
315 MS_EXCEPTION_IF_NULL(bev);
316 std::shared_ptr<TcpConnection> conn = nullptr;
317 if (client_accept_) {
318 conn = (client_accept_(*this));
319 } else {
320 conn = std::make_shared<TcpConnection>(bev, fd, this);
321 }
322
323 return conn;
324 }
325
GetServerReceive() const326 OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; }
327
SignalCallback(evutil_socket_t,std::int16_t,void * data)328 void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
329 MS_EXCEPTION_IF_NULL(data);
330 auto server = reinterpret_cast<class TcpServer *>(data);
331 struct event_base *base = server->base_;
332 MS_EXCEPTION_IF_NULL(base);
333 struct timeval delay = {0, 0};
334 MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
335 if (event_base_loopexit(base, &delay) == -1) {
336 MS_LOG(ERROR) << "Event base loop exit failed.";
337 }
338 }
339
ReadCallback(struct bufferevent * bev,void * connection)340 void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
341 MS_EXCEPTION_IF_NULL(bev);
342 MS_EXCEPTION_IF_NULL(connection);
343
344 auto conn = static_cast<class TcpConnection *>(connection);
345 struct evbuffer *buf = bufferevent_get_input(bev);
346 MS_EXCEPTION_IF_NULL(buf);
347 char read_buffer[kMessageChunkLength];
348 while (EVBUFFER_LENGTH(buf) > 0) {
349 int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer));
350 if (read == -1) {
351 MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
352 }
353 conn->OnReadHandler(read_buffer, IntToSize(read));
354 }
355 }
356
EventCallback(struct bufferevent * bev,std::int16_t events,void * data)357 void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
358 MS_EXCEPTION_IF_NULL(bev);
359 MS_EXCEPTION_IF_NULL(data);
360 struct evbuffer *output = bufferevent_get_output(bev);
361 MS_EXCEPTION_IF_NULL(output);
362 auto conn = static_cast<class TcpConnection *>(data);
363 auto srv = const_cast<TcpServer *>(conn->GetServer());
364 MS_EXCEPTION_IF_NULL(srv);
365
366 if (events & BEV_EVENT_EOF) {
367 MS_LOG(INFO) << "Event buffer end of file, a client is disconnected from this server!";
368 // Notify about disconnection
369 if (srv->client_disconnection_) {
370 srv->client_disconnection_(*srv, *conn);
371 }
372 // Free connection structures
373 srv->RemoveConnection(conn->GetFd());
374 } else if (events & BEV_EVENT_ERROR) {
375 MS_LOG(WARNING) << "Connect to server error.";
376 if (PSContext::instance()->enable_ssl()) {
377 uint64_t err = bufferevent_get_openssl_error(bev);
378 MS_LOG(WARNING) << "The error number is:" << err;
379
380 MS_LOG(WARNING) << "Error message:" << ERR_reason_error_string(err)
381 << ", the error lib:" << ERR_lib_error_string(err)
382 << ", the error func:" << ERR_func_error_string(err);
383 }
384 // Free connection structures
385 srv->RemoveConnection(conn->GetFd());
386
387 // Notify about disconnection
388 if (srv->client_disconnection_) {
389 srv->client_disconnection_(*srv, *conn);
390 }
391 } else {
392 MS_LOG(WARNING) << "Unhandled event:" << events;
393 }
394 }
395
TimerCallback(evutil_socket_t,int16_t,void * arg)396 void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) {
397 MS_EXCEPTION_IF_NULL(arg);
398 auto tcp_server = reinterpret_cast<TcpServer *>(arg);
399 if (tcp_server->on_timer_callback_) {
400 tcp_server->on_timer_callback_();
401 }
402 }
403
TimerOnceCallback(evutil_socket_t,int16_t,void * arg)404 void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
405 MS_EXCEPTION_IF_NULL(arg);
406 auto tcp_server = reinterpret_cast<TcpServer *>(arg);
407 if (tcp_server->on_timer_once_callback_) {
408 tcp_server->on_timer_once_callback_(*tcp_server);
409 }
410 }
411
SetTcpNoDelay(const evutil_socket_t & fd)412 void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
413 const int one = 1;
414 int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
415 if (ret < 0) {
416 MS_LOG(EXCEPTION) << "Set socket no delay failed!";
417 }
418 }
419
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<CommMessage> & message)420 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<CommMessage> &message) {
421 MS_EXCEPTION_IF_NULL(conn);
422 MS_EXCEPTION_IF_NULL(message);
423 return conn->SendMessage(message);
424 }
425
SendMessage(const std::shared_ptr<TcpConnection> & conn,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)426 bool TcpServer::SendMessage(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
427 const Protos &protos, const void *data, size_t size) {
428 MS_EXCEPTION_IF_NULL(conn);
429 MS_EXCEPTION_IF_NULL(meta);
430 MS_EXCEPTION_IF_NULL(data);
431 return conn->SendMessage(meta, protos, data, size);
432 }
433
SendMessage(const std::shared_ptr<CommMessage> & message)434 void TcpServer::SendMessage(const std::shared_ptr<CommMessage> &message) {
435 MS_EXCEPTION_IF_NULL(message);
436 std::lock_guard<std::mutex> lock(connection_mutex_);
437
438 for (auto it = connections_.begin(); it != connections_.end(); ++it) {
439 SendMessage(it->second, message);
440 }
441 }
442
BoundPort() const443 uint16_t TcpServer::BoundPort() const { return server_port_; }
444
BoundIp() const445 std::string TcpServer::BoundIp() const { return server_address_; }
446
ConnectionNum() const447 int TcpServer::ConnectionNum() const { return SizeToInt(connections_.size()); }
448
Connections() const449 const std::map<evutil_socket_t, std::shared_ptr<TcpConnection>> &TcpServer::Connections() const { return connections_; }
450
SetMessageCallback(const OnServerReceiveMessage & cb)451 void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
452 } // namespace core
453 } // namespace ps
454 } // namespace mindspore
455