1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/tcp_client.h"
18
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <sys/socket.h>
27 #include <cstdlib>
28 #include <cstring>
29 #include <iostream>
30 #include <string>
31 #include <utility>
32
33 namespace mindspore {
34 namespace ps {
35 namespace core {
36 event_base *TcpClient::event_base_ = nullptr;
37 std::mutex TcpClient::event_base_mutex_;
38 bool TcpClient::is_started_ = false;
39
TcpClient(const std::string & address,std::uint16_t port,Configuration * const config)40 TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config)
41 : event_timeout_(nullptr),
42 buffer_event_(nullptr),
43 server_address_(std::move(address)),
44 server_port_(port),
45 is_stop_(true),
46 is_connected_(false),
47 config_(config) {
48 message_handler_.SetCallback(
49 [this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
50 if (message_callback_) {
51 message_callback_(meta, protos, data, size);
52 }
53 });
54 }
55
~TcpClient()56 TcpClient::~TcpClient() {
57 if (buffer_event_) {
58 bufferevent_free(buffer_event_);
59 buffer_event_ = nullptr;
60 }
61 if (event_timeout_) {
62 event_free(event_timeout_);
63 event_timeout_ = nullptr;
64 }
65 }
66
GetServerAddress() const67 std::string TcpClient::GetServerAddress() const { return server_address_; }
68
set_disconnected_callback(const OnDisconnected & disconnected)69 void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; }
70
set_connected_callback(const OnConnected & connected)71 void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }
72
WaitConnected(const uint32_t & connected_timeout)73 bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
74 std::unique_lock<std::mutex> lock(connection_mutex_);
75 bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
76 [this] { return this->is_connected_.load(); });
77 return res;
78 }
79
Init()80 void TcpClient::Init() {
81 std::lock_guard<std::mutex> lock(connection_mutex_);
82 if (buffer_event_) {
83 bufferevent_free(buffer_event_);
84 buffer_event_ = nullptr;
85 }
86 if (!CommUtil::CheckIp(server_address_)) {
87 MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
88 }
89
90 int result = evthread_use_pthreads();
91 if (result != 0) {
92 MS_LOG(EXCEPTION) << "Use event pthread failed!";
93 }
94 if (event_base_ == nullptr) {
95 event_base_ = event_base_new();
96 MS_EXCEPTION_IF_NULL(event_base_);
97 is_stop_ = false;
98 }
99
100 sockaddr_in sin{};
101 if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
102 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
103 }
104 sin.sin_family = AF_INET;
105 sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
106 sin.sin_port = htons(server_port_);
107
108 if (!PSContext::instance()->enable_ssl()) {
109 MS_LOG(INFO) << "SSL is disable.";
110 buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
111 } else {
112 if (!EstablishSSL()) {
113 MS_LOG(WARNING) << "Establish SSL failed.";
114 return;
115 }
116 }
117
118 MS_EXCEPTION_IF_NULL(buffer_event_);
119
120 bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
121 if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
122 MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
123 }
124
125 int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
126 if (result_code < 0) {
127 MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
128 }
129 }
130
StartWithDelay(int seconds)131 void TcpClient::StartWithDelay(int seconds) {
132 std::lock_guard<std::mutex> lock(connection_mutex_);
133 if (buffer_event_) {
134 return;
135 }
136
137 event_base_ = event_base_new();
138 MS_EXCEPTION_IF_NULL(event_base_);
139
140 timeval timeout_value{};
141 timeout_value.tv_sec = seconds;
142 timeout_value.tv_usec = 0;
143
144 event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
145 MS_EXCEPTION_IF_NULL(event_timeout_);
146 if (evtimer_add(event_timeout_, &timeout_value) == -1) {
147 MS_LOG(EXCEPTION) << "Event timeout failed!";
148 }
149 }
150
Stop()151 void TcpClient::Stop() {
152 MS_EXCEPTION_IF_NULL(event_base_);
153 std::lock_guard<std::mutex> lock(connection_mutex_);
154 MS_LOG(INFO) << "Stop tcp client!";
155 int ret = event_base_loopbreak(event_base_);
156 if (ret != 0) {
157 MS_LOG(ERROR) << "Event base loop break failed!";
158 }
159 }
160
SetTcpNoDelay(const evutil_socket_t & fd)161 void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
162 const int one = 1;
163 int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
164 if (ret < 0) {
165 MS_LOG(EXCEPTION) << "Set socket no delay failed!";
166 }
167 }
168
TimeoutCallback(evutil_socket_t,std::int16_t,void * arg)169 void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
170 MS_EXCEPTION_IF_NULL(arg);
171 auto tcp_client = reinterpret_cast<TcpClient *>(arg);
172 tcp_client->Init();
173 }
174
ReadCallback(struct bufferevent * bev,void * ctx)175 void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
176 MS_EXCEPTION_IF_NULL(bev);
177 MS_EXCEPTION_IF_NULL(ctx);
178 auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
179
180 char read_buffer[kMessageChunkLength];
181 size_t read = 0;
182
183 while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
184 tcp_client->OnReadHandler(read_buffer, read);
185 }
186 }
187
OnReadHandler(const void * buf,size_t num)188 void TcpClient::OnReadHandler(const void *buf, size_t num) {
189 MS_EXCEPTION_IF_NULL(buf);
190 if (read_callback_) {
191 read_callback_(buf, num);
192 }
193 message_handler_.ReceiveMessage(buf, num);
194 }
195
TimerCallback(evutil_socket_t,int16_t,void * arg)196 void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
197 MS_EXCEPTION_IF_NULL(arg);
198 auto tcp_client = reinterpret_cast<TcpClient *>(arg);
199 if (tcp_client->on_timer_callback_) {
200 tcp_client->on_timer_callback_();
201 }
202 }
203
NotifyConnected()204 void TcpClient::NotifyConnected() {
205 MS_LOG(INFO) << "Client connected to the server!";
206 is_connected_ = true;
207 connection_cond_.notify_all();
208 }
209
EstablishSSL()210 bool TcpClient::EstablishSSL() {
211 MS_LOG(INFO) << "Enable ssl support.";
212
213 SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx());
214 MS_ERROR_IF_NULL_W_RET_VAL(ssl, false);
215 MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false);
216
217 buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
218 BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
219 return true;
220 }
221
EventCallback(struct bufferevent * bev,std::int16_t events,void * ptr)222 void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
223 MS_EXCEPTION_IF_NULL(bev);
224 MS_EXCEPTION_IF_NULL(ptr);
225 auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
226 if (events & BEV_EVENT_CONNECTED) {
227 // Connected
228 if (tcp_client->connected_callback_) {
229 tcp_client->connected_callback_();
230 }
231 tcp_client->NotifyConnected();
232 evutil_socket_t fd = bufferevent_getfd(bev);
233 SetTcpNoDelay(fd);
234 MS_LOG(INFO) << "Client connected!";
235 } else if (events & BEV_EVENT_ERROR) {
236 MS_LOG(WARNING) << "The client will retry to connect to the server!";
237 if (tcp_client->disconnected_callback_) {
238 tcp_client->disconnected_callback_();
239 }
240 } else if (events & BEV_EVENT_EOF) {
241 MS_LOG(WARNING) << "Client connected end of file";
242 }
243 }
244
Start()245 void TcpClient::Start() {
246 event_base_mutex_.lock();
247 if (is_started_) {
248 event_base_mutex_.unlock();
249 return;
250 }
251 is_started_ = true;
252 event_base_mutex_.unlock();
253 MS_EXCEPTION_IF_NULL(event_base_);
254 int ret = event_base_dispatch(event_base_);
255 MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
256 MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
257 << "Event base dispatch failed with no events pending or active!";
258 MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
259 MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
260 }
261
StartWithNoBlock()262 void TcpClient::StartWithNoBlock() {
263 std::lock_guard<std::mutex> lock(connection_mutex_);
264 MS_LOG(INFO) << "Start tcp client with no block!";
265 MS_EXCEPTION_IF_NULL(event_base_);
266 int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
267 MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
268 MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
269 MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
270 MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
271 }
272
SetMessageCallback(const OnMessage & cb)273 void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
274
SendMessage(const CommMessage & message) const275 bool TcpClient::SendMessage(const CommMessage &message) const {
276 MS_EXCEPTION_IF_NULL(buffer_event_);
277 bufferevent_lock(buffer_event_);
278 bool res = true;
279 size_t buf_size = message.ByteSizeLong();
280 uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
281 MessageHeader header;
282 header.message_proto_ = Protos::PROTOBUF;
283 header.message_length_ = buf_size;
284 header.message_meta_length_ = meta_size;
285 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
286 MS_LOG(ERROR) << "Event buffer add header failed!";
287 res = false;
288 }
289 if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
290 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
291 res = false;
292 }
293 if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
294 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
295 res = false;
296 }
297 bufferevent_unlock(buffer_event_);
298 return res;
299 }
300
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)301 bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
302 size_t size) {
303 MS_EXCEPTION_IF_NULL(buffer_event_);
304 MS_EXCEPTION_IF_NULL(meta);
305 MS_EXCEPTION_IF_NULL(data);
306 bufferevent_lock(buffer_event_);
307 bool res = true;
308
309 MessageHeader header;
310 header.message_proto_ = protos;
311 header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
312 header.message_length_ = size + header.message_meta_length_;
313
314 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
315 MS_LOG(ERROR) << "Event buffer add header failed!";
316 res = false;
317 }
318 if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
319 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
320 res = false;
321 }
322 if (bufferevent_write(buffer_event_, data, size) == -1) {
323 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
324 res = false;
325 }
326 int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
327 if (result < 0) {
328 MS_LOG(ERROR) << "Bufferevent flush failed!";
329 res = false;
330 }
331 bufferevent_unlock(buffer_event_);
332 return res;
333 }
334
set_timer_callback(const OnTimer & timer)335 void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
336
eventbase() const337 const event_base &TcpClient::eventbase() const { return *event_base_; }
338 } // namespace core
339 } // namespace ps
340 } // namespace mindspore
341