• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/tcp_client.h"
18 
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <sys/socket.h>
27 #include <cstdlib>
28 #include <cstring>
29 #include <iostream>
30 #include <string>
31 #include <utility>
32 
33 namespace mindspore {
34 namespace ps {
35 namespace core {
36 event_base *TcpClient::event_base_ = nullptr;
37 std::mutex TcpClient::event_base_mutex_;
38 bool TcpClient::is_started_ = false;
39 
TcpClient(const std::string & address,std::uint16_t port,Configuration * const config)40 TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config)
41     : event_timeout_(nullptr),
42       buffer_event_(nullptr),
43       server_address_(std::move(address)),
44       server_port_(port),
45       is_stop_(true),
46       is_connected_(false),
47       config_(config) {
48   message_handler_.SetCallback(
49     [this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
50       if (message_callback_) {
51         message_callback_(meta, protos, data, size);
52       }
53     });
54 }
55 
~TcpClient()56 TcpClient::~TcpClient() {
57   if (buffer_event_) {
58     bufferevent_free(buffer_event_);
59     buffer_event_ = nullptr;
60   }
61   if (event_timeout_) {
62     event_free(event_timeout_);
63     event_timeout_ = nullptr;
64   }
65 }
66 
GetServerAddress() const67 std::string TcpClient::GetServerAddress() const { return server_address_; }
68 
set_disconnected_callback(const OnDisconnected & disconnected)69 void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; }
70 
set_connected_callback(const OnConnected & connected)71 void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }
72 
WaitConnected(const uint32_t & connected_timeout)73 bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
74   std::unique_lock<std::mutex> lock(connection_mutex_);
75   bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
76                                        [this] { return this->is_connected_.load(); });
77   return res;
78 }
79 
Init()80 void TcpClient::Init() {
81   std::lock_guard<std::mutex> lock(connection_mutex_);
82   if (buffer_event_) {
83     bufferevent_free(buffer_event_);
84     buffer_event_ = nullptr;
85   }
86   if (!CommUtil::CheckIp(server_address_)) {
87     MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
88   }
89 
90   int result = evthread_use_pthreads();
91   if (result != 0) {
92     MS_LOG(EXCEPTION) << "Use event pthread failed!";
93   }
94   if (event_base_ == nullptr) {
95     event_base_ = event_base_new();
96     MS_EXCEPTION_IF_NULL(event_base_);
97     is_stop_ = false;
98   }
99 
100   sockaddr_in sin{};
101   if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
102     MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
103   }
104   sin.sin_family = AF_INET;
105   sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
106   sin.sin_port = htons(server_port_);
107 
108   if (!PSContext::instance()->enable_ssl()) {
109     MS_LOG(INFO) << "SSL is disable.";
110     buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
111   } else {
112     if (!EstablishSSL()) {
113       MS_LOG(WARNING) << "Establish SSL failed.";
114       return;
115     }
116   }
117 
118   MS_EXCEPTION_IF_NULL(buffer_event_);
119 
120   bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
121   if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
122     MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
123   }
124 
125   int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
126   if (result_code < 0) {
127     MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
128   }
129 }
130 
StartWithDelay(int seconds)131 void TcpClient::StartWithDelay(int seconds) {
132   std::lock_guard<std::mutex> lock(connection_mutex_);
133   if (buffer_event_) {
134     return;
135   }
136 
137   event_base_ = event_base_new();
138   MS_EXCEPTION_IF_NULL(event_base_);
139 
140   timeval timeout_value{};
141   timeout_value.tv_sec = seconds;
142   timeout_value.tv_usec = 0;
143 
144   event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
145   MS_EXCEPTION_IF_NULL(event_timeout_);
146   if (evtimer_add(event_timeout_, &timeout_value) == -1) {
147     MS_LOG(EXCEPTION) << "Event timeout failed!";
148   }
149 }
150 
Stop()151 void TcpClient::Stop() {
152   MS_EXCEPTION_IF_NULL(event_base_);
153   std::lock_guard<std::mutex> lock(connection_mutex_);
154   MS_LOG(INFO) << "Stop tcp client!";
155   int ret = event_base_loopbreak(event_base_);
156   if (ret != 0) {
157     MS_LOG(ERROR) << "Event base loop break failed!";
158   }
159 }
160 
SetTcpNoDelay(const evutil_socket_t & fd)161 void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
162   const int one = 1;
163   int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
164   if (ret < 0) {
165     MS_LOG(EXCEPTION) << "Set socket no delay failed!";
166   }
167 }
168 
TimeoutCallback(evutil_socket_t,std::int16_t,void * arg)169 void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
170   MS_EXCEPTION_IF_NULL(arg);
171   auto tcp_client = reinterpret_cast<TcpClient *>(arg);
172   tcp_client->Init();
173 }
174 
ReadCallback(struct bufferevent * bev,void * ctx)175 void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
176   MS_EXCEPTION_IF_NULL(bev);
177   MS_EXCEPTION_IF_NULL(ctx);
178   auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
179 
180   char read_buffer[kMessageChunkLength];
181   size_t read = 0;
182 
183   while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
184     tcp_client->OnReadHandler(read_buffer, read);
185   }
186 }
187 
OnReadHandler(const void * buf,size_t num)188 void TcpClient::OnReadHandler(const void *buf, size_t num) {
189   MS_EXCEPTION_IF_NULL(buf);
190   if (read_callback_) {
191     read_callback_(buf, num);
192   }
193   message_handler_.ReceiveMessage(buf, num);
194 }
195 
TimerCallback(evutil_socket_t,int16_t,void * arg)196 void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
197   MS_EXCEPTION_IF_NULL(arg);
198   auto tcp_client = reinterpret_cast<TcpClient *>(arg);
199   if (tcp_client->on_timer_callback_) {
200     tcp_client->on_timer_callback_();
201   }
202 }
203 
NotifyConnected()204 void TcpClient::NotifyConnected() {
205   MS_LOG(INFO) << "Client connected to the server!";
206   is_connected_ = true;
207   connection_cond_.notify_all();
208 }
209 
EstablishSSL()210 bool TcpClient::EstablishSSL() {
211   MS_LOG(INFO) << "Enable ssl support.";
212 
213   SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx());
214   MS_ERROR_IF_NULL_W_RET_VAL(ssl, false);
215   MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false);
216 
217   buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
218                                                  BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
219   return true;
220 }
221 
EventCallback(struct bufferevent * bev,std::int16_t events,void * ptr)222 void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
223   MS_EXCEPTION_IF_NULL(bev);
224   MS_EXCEPTION_IF_NULL(ptr);
225   auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
226   if (events & BEV_EVENT_CONNECTED) {
227     // Connected
228     if (tcp_client->connected_callback_) {
229       tcp_client->connected_callback_();
230     }
231     tcp_client->NotifyConnected();
232     evutil_socket_t fd = bufferevent_getfd(bev);
233     SetTcpNoDelay(fd);
234     MS_LOG(INFO) << "Client connected!";
235   } else if (events & BEV_EVENT_ERROR) {
236     MS_LOG(WARNING) << "The client will retry to connect to the server!";
237     if (tcp_client->disconnected_callback_) {
238       tcp_client->disconnected_callback_();
239     }
240   } else if (events & BEV_EVENT_EOF) {
241     MS_LOG(WARNING) << "Client connected end of file";
242   }
243 }
244 
Start()245 void TcpClient::Start() {
246   event_base_mutex_.lock();
247   if (is_started_) {
248     event_base_mutex_.unlock();
249     return;
250   }
251   is_started_ = true;
252   event_base_mutex_.unlock();
253   MS_EXCEPTION_IF_NULL(event_base_);
254   int ret = event_base_dispatch(event_base_);
255   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
256   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
257     << "Event base dispatch failed with no events pending or active!";
258   MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
259   MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
260 }
261 
StartWithNoBlock()262 void TcpClient::StartWithNoBlock() {
263   std::lock_guard<std::mutex> lock(connection_mutex_);
264   MS_LOG(INFO) << "Start tcp client with no block!";
265   MS_EXCEPTION_IF_NULL(event_base_);
266   int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
267   MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
268   MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
269   MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
270   MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
271 }
272 
SetMessageCallback(const OnMessage & cb)273 void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
274 
SendMessage(const CommMessage & message) const275 bool TcpClient::SendMessage(const CommMessage &message) const {
276   MS_EXCEPTION_IF_NULL(buffer_event_);
277   bufferevent_lock(buffer_event_);
278   bool res = true;
279   size_t buf_size = message.ByteSizeLong();
280   uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
281   MessageHeader header;
282   header.message_proto_ = Protos::PROTOBUF;
283   header.message_length_ = buf_size;
284   header.message_meta_length_ = meta_size;
285   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
286     MS_LOG(ERROR) << "Event buffer add header failed!";
287     res = false;
288   }
289   if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
290     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
291     res = false;
292   }
293   if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
294     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
295     res = false;
296   }
297   bufferevent_unlock(buffer_event_);
298   return res;
299 }
300 
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)301 bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
302                             size_t size) {
303   MS_EXCEPTION_IF_NULL(buffer_event_);
304   MS_EXCEPTION_IF_NULL(meta);
305   MS_EXCEPTION_IF_NULL(data);
306   bufferevent_lock(buffer_event_);
307   bool res = true;
308 
309   MessageHeader header;
310   header.message_proto_ = protos;
311   header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
312   header.message_length_ = size + header.message_meta_length_;
313 
314   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
315     MS_LOG(ERROR) << "Event buffer add header failed!";
316     res = false;
317   }
318   if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
319     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
320     res = false;
321   }
322   if (bufferevent_write(buffer_event_, data, size) == -1) {
323     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
324     res = false;
325   }
326   int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
327   if (result < 0) {
328     MS_LOG(ERROR) << "Bufferevent flush failed!";
329     res = false;
330   }
331   bufferevent_unlock(buffer_event_);
332   return res;
333 }
334 
set_timer_callback(const OnTimer & timer)335 void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
336 
eventbase() const337 const event_base &TcpClient::eventbase() const { return *event_base_; }
338 }  // namespace core
339 }  // namespace ps
340 }  // namespace mindspore
341