1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/communicator/tcp_client.h"
18
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <sys/socket.h>
27 #include <cstdlib>
28 #include <cstring>
29 #include <iostream>
30 #include <string>
31 #include <utility>
32
33 namespace mindspore {
34 namespace ps {
35 namespace core {
36 event_base *TcpClient::event_base_ = nullptr;
37 std::mutex TcpClient::event_base_mutex_;
38 bool TcpClient::is_started_ = false;
39
TcpClient(const std::string & address,std::uint16_t port,NodeRole peer_role)40 TcpClient::TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role)
41 : event_timeout_(nullptr),
42 buffer_event_(nullptr),
43 server_address_(std::move(address)),
44 server_port_(port),
45 peer_role_(peer_role),
46 connection_status_(-1) {
47 message_handler_.SetCallback(
48 [this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
49 if (message_callback_) {
50 message_callback_(meta, protos, data, size);
51 }
52 });
53 }
54
~TcpClient()55 TcpClient::~TcpClient() {
56 if (buffer_event_) {
57 bufferevent_free(buffer_event_);
58 buffer_event_ = nullptr;
59 }
60 if (event_timeout_) {
61 event_free(event_timeout_);
62 event_timeout_ = nullptr;
63 }
64 }
65
GetServerAddress() const66 std::string TcpClient::GetServerAddress() const { return server_address_; }
67
set_disconnected_callback(const OnDisconnected & disconnected)68 void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; }
69
set_connected_callback(const OnConnected & connected)70 void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }
71
PeerRoleName() const72 std::string TcpClient::PeerRoleName() const {
73 switch (peer_role_) {
74 case SERVER:
75 return "Server";
76 case WORKER:
77 return "Worker";
78 case SCHEDULER:
79 return "Scheduler";
80 default:
81 return "RoleUndefined";
82 }
83 }
84
WaitConnected(const uint32_t & connected_timeout)85 bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
86 std::unique_lock<std::mutex> lock(connection_mutex_);
87 bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
88 [this] { return this->connection_status_ == 1; });
89 return res;
90 }
91
Init()92 void TcpClient::Init() {
93 std::lock_guard<std::mutex> lock(connection_mutex_);
94 if (connection_status_ != -1) {
95 return;
96 }
97 connection_status_ = 0;
98 if (buffer_event_) {
99 bufferevent_free(buffer_event_);
100 buffer_event_ = nullptr;
101 }
102 if (!CommUtil::CheckIp(server_address_)) {
103 MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
104 }
105
106 int result = evthread_use_pthreads();
107 if (result != 0) {
108 MS_LOG(EXCEPTION) << "Use event pthread failed!";
109 }
110 if (event_base_ == nullptr) {
111 event_base_ = event_base_new();
112 MS_EXCEPTION_IF_NULL(event_base_);
113 }
114
115 if (!PSContext::instance()->enable_ssl()) {
116 MS_LOG(INFO) << "SSL is disable.";
117 buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
118 } else {
119 if (!EstablishSSL()) {
120 MS_LOG(WARNING) << "Establish SSL failed.";
121 return;
122 }
123 }
124
125 MS_EXCEPTION_IF_NULL(buffer_event_);
126
127 bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
128 if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
129 MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
130 }
131
132 if (server_port_ > 0) {
133 sockaddr_in sin{};
134 if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
135 MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
136 }
137 sin.sin_family = AF_INET;
138 sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
139 sin.sin_port = htons(server_port_);
140
141 int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
142 if (result_code < 0) {
143 MS_LOG(WARNING) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
144 }
145 }
146 }
147
StartWithDelay(int seconds)148 void TcpClient::StartWithDelay(int seconds) {
149 std::lock_guard<std::mutex> lock(connection_mutex_);
150 if (buffer_event_) {
151 return;
152 }
153
154 event_base_ = event_base_new();
155 MS_EXCEPTION_IF_NULL(event_base_);
156
157 timeval timeout_value{};
158 timeout_value.tv_sec = seconds;
159 timeout_value.tv_usec = 0;
160
161 event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
162 MS_EXCEPTION_IF_NULL(event_timeout_);
163 if (evtimer_add(event_timeout_, &timeout_value) == -1) {
164 MS_LOG(EXCEPTION) << "Event timeout failed!";
165 }
166 }
167
Stop()168 void TcpClient::Stop() {
169 MS_EXCEPTION_IF_NULL(event_base_);
170 std::lock_guard<std::mutex> lock(connection_mutex_);
171 if (event_base_got_break(event_base_)) {
172 MS_LOG(DEBUG) << "The event base has already been stopped!";
173 return;
174 }
175
176 MS_LOG(INFO) << "Stop tcp client!";
177 int ret = event_base_loopbreak(event_base_);
178 if (ret != 0) {
179 MS_LOG(ERROR) << "Event base loop break failed!";
180 }
181 }
182
SetTcpNoDelay(const evutil_socket_t & fd)183 void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
184 const int one = 1;
185 int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
186 if (ret < 0) {
187 MS_LOG(EXCEPTION) << "Set socket no delay failed!";
188 }
189 }
190
TimeoutCallback(evutil_socket_t,std::int16_t,void * const arg)191 void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *const arg) {
192 try {
193 MS_EXCEPTION_IF_NULL(arg);
194 auto tcp_client = reinterpret_cast<TcpClient *>(arg);
195 tcp_client->Init();
196 } catch (const std::exception &e) {
197 MS_LOG(ERROR) << "Catch exception: " << e.what();
198 }
199 }
200
ReadCallback(struct bufferevent * bev,void * const ctx)201 void TcpClient::ReadCallback(struct bufferevent *bev, void *const ctx) {
202 try {
203 MS_EXCEPTION_IF_NULL(ctx);
204 auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
205 tcp_client->ReadCallbackInner(bev);
206 } catch (const std::exception &e) {
207 MS_LOG(ERROR) << "Catch exception: " << e.what();
208 }
209 }
210
ReadCallbackInner(struct bufferevent * bev)211 void TcpClient::ReadCallbackInner(struct bufferevent *bev) {
212 MS_EXCEPTION_IF_NULL(bev);
213
214 char read_buffer[kMessageChunkLength];
215 size_t read = 0;
216
217 while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
218 OnReadHandler(read_buffer, read);
219 }
220 }
221
OnReadHandler(const void * buf,size_t num)222 void TcpClient::OnReadHandler(const void *buf, size_t num) {
223 MS_EXCEPTION_IF_NULL(buf);
224 if (read_callback_) {
225 read_callback_(buf, num);
226 }
227 message_handler_.ReceiveMessage(buf, num);
228 }
229
TimerCallback(evutil_socket_t,int16_t,void * arg)230 void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
231 MS_EXCEPTION_IF_NULL(arg);
232 auto tcp_client = reinterpret_cast<TcpClient *>(arg);
233 if (tcp_client->on_timer_callback_) {
234 tcp_client->on_timer_callback_();
235 }
236 }
237
NotifyConnected()238 void TcpClient::NotifyConnected() {
239 MS_LOG(INFO) << "Client connected to the server! Peer " << PeerRoleName() << " ip: " << server_address_
240 << ", port: " << server_port_;
241 connection_status_ = 1;
242 connection_cond_.notify_all();
243 }
244
EstablishSSL()245 bool TcpClient::EstablishSSL() {
246 MS_LOG(INFO) << "Enable ssl support.";
247
248 SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx());
249 MS_ERROR_IF_NULL_W_RET_VAL(ssl, false);
250 MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false);
251
252 buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
253 BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
254 return true;
255 }
256
EventCallback(struct bufferevent * bev,std::int16_t events,void * const ptr)257 void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *const ptr) {
258 try {
259 MS_EXCEPTION_IF_NULL(ptr);
260 auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
261 tcp_client->EventCallbackInner(bev, events);
262 } catch (const std::exception &e) {
263 MS_LOG(ERROR) << "Catch exception: " << e.what();
264 }
265 }
266
EventCallbackInner(struct bufferevent * bev,std::int16_t events)267 void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events) {
268 MS_EXCEPTION_IF_NULL(bev);
269 if (events & BEV_EVENT_CONNECTED) {
270 // Connected
271 if (connected_callback_) {
272 connected_callback_();
273 }
274 NotifyConnected();
275 evutil_socket_t fd = bufferevent_getfd(bev);
276 SetTcpNoDelay(fd);
277 MS_LOG(INFO) << "Client connected! Peer " << PeerRoleName() << " ip: " << server_address_
278 << ", port: " << server_port_;
279 } else if (events & BEV_EVENT_ERROR) {
280 if (PSContext::instance()->enable_ssl()) {
281 uint64_t err = bufferevent_get_openssl_error(bev);
282 const uint64_t server_not_start_err = 5;
283 if (err != server_not_start_err) {
284 MS_LOG(WARNING) << "The error number is:" << err << ", error message:" << ERR_reason_error_string(err)
285 << ", the error lib:" << ERR_lib_error_string(err)
286 << ", the error func:" << ERR_func_error_string(err);
287 return;
288 }
289 }
290 connection_status_ = -1;
291 if (disconnected_callback_) {
292 MS_LOG(WARNING) << "The client will retry to connect to the server! Peer " << PeerRoleName()
293 << " ip: " << server_address_ << ", port: " << server_port_;
294 disconnected_callback_();
295 }
296 } else if (events & BEV_EVENT_EOF) {
297 MS_LOG(WARNING) << "Client connected end of file! Peer " << PeerRoleName() << " ip: " << server_address_
298 << ", port: " << server_port_;
299 connection_status_ = -1;
300 if (disconnected_callback_) {
301 disconnected_callback_();
302 }
303 }
304 }
305
Start()306 void TcpClient::Start() {
307 event_base_mutex_.lock();
308 if (is_started_) {
309 event_base_mutex_.unlock();
310 return;
311 }
312 is_started_ = true;
313 event_base_mutex_.unlock();
314 MS_EXCEPTION_IF_NULL(event_base_);
315 int ret = event_base_dispatch(event_base_);
316 // is_started_ should be false when finish dispatch
317 is_started_ = false;
318 MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base dispatch success!";
319 MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
320 << "Event base dispatch failed with no events pending or active!";
321 MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr)
322 << "Event base dispatch failed with error occurred!";
323 MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr)
324 << "Event base dispatch with unexpected error code!";
325 }
326
StartWithNoBlock()327 void TcpClient::StartWithNoBlock() {
328 std::lock_guard<std::mutex> lock(connection_mutex_);
329 MS_LOG(INFO) << "Start tcp client with no block!";
330 MS_EXCEPTION_IF_NULL(event_base_);
331 int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
332 MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base loop success!";
333 MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
334 << "Event base loop failed with no events pending or active!";
335 MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr) << "Event base loop failed with error occurred!";
336 MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr) << "Event base loop with unexpected error code!";
337 }
338
SetMessageCallback(const OnMessage & cb)339 void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
340
SendMessage(const CommMessage & message) const341 bool TcpClient::SendMessage(const CommMessage &message) const {
342 MS_EXCEPTION_IF_NULL(buffer_event_);
343 bufferevent_lock(buffer_event_);
344 bool res = true;
345 size_t buf_size = message.ByteSizeLong();
346 uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
347 MessageHeader header;
348 header.message_proto_ = Protos::PROTOBUF;
349 header.message_length_ = buf_size;
350 header.message_meta_length_ = meta_size;
351 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
352 MS_LOG(ERROR) << "Event buffer add header failed!";
353 res = false;
354 }
355 if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
356 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
357 res = false;
358 }
359 if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
360 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
361 res = false;
362 }
363 bufferevent_unlock(buffer_event_);
364 return res;
365 }
366
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)367 bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
368 size_t size) {
369 MS_EXCEPTION_IF_NULL(buffer_event_);
370 MS_EXCEPTION_IF_NULL(meta);
371 MS_EXCEPTION_IF_NULL(data);
372 bufferevent_lock(buffer_event_);
373 bool res = true;
374
375 MessageHeader header;
376 header.message_proto_ = protos;
377 header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
378 header.message_length_ = size + header.message_meta_length_;
379
380 if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
381 MS_LOG(ERROR) << "Event buffer add header failed!";
382 res = false;
383 }
384 if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
385 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
386 res = false;
387 }
388 if (bufferevent_write(buffer_event_, data, size) == -1) {
389 MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
390 res = false;
391 }
392 int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
393 if (result < 0) {
394 MS_LOG(ERROR) << "Bufferevent flush failed!";
395 res = false;
396 }
397 bufferevent_unlock(buffer_event_);
398 return res;
399 }
400
set_timer_callback(const OnTimer & timer)401 void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
402
eventbase() const403 const event_base &TcpClient::eventbase() const { return *event_base_; }
404 } // namespace core
405 } // namespace ps
406 } // namespace mindspore
407