• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/communicator/tcp_client.h"
18 
19 #include <arpa/inet.h>
20 #include <event2/buffer.h>
21 #include <event2/buffer_compat.h>
22 #include <event2/bufferevent.h>
23 #include <event2/event.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <sys/socket.h>
27 #include <cstdlib>
28 #include <cstring>
29 #include <iostream>
30 #include <string>
31 #include <utility>
32 
33 namespace mindspore {
34 namespace ps {
35 namespace core {
36 event_base *TcpClient::event_base_ = nullptr;
37 std::mutex TcpClient::event_base_mutex_;
38 bool TcpClient::is_started_ = false;
39 
TcpClient(const std::string & address,std::uint16_t port,NodeRole peer_role)40 TcpClient::TcpClient(const std::string &address, std::uint16_t port, NodeRole peer_role)
41     : event_timeout_(nullptr),
42       buffer_event_(nullptr),
43       server_address_(std::move(address)),
44       server_port_(port),
45       peer_role_(peer_role),
46       connection_status_(-1) {
47   message_handler_.SetCallback(
48     [this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
49       if (message_callback_) {
50         message_callback_(meta, protos, data, size);
51       }
52     });
53 }
54 
~TcpClient()55 TcpClient::~TcpClient() {
56   if (buffer_event_) {
57     bufferevent_free(buffer_event_);
58     buffer_event_ = nullptr;
59   }
60   if (event_timeout_) {
61     event_free(event_timeout_);
62     event_timeout_ = nullptr;
63   }
64 }
65 
GetServerAddress() const66 std::string TcpClient::GetServerAddress() const { return server_address_; }
67 
set_disconnected_callback(const OnDisconnected & disconnected)68 void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { disconnected_callback_ = disconnected; }
69 
set_connected_callback(const OnConnected & connected)70 void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }
71 
PeerRoleName() const72 std::string TcpClient::PeerRoleName() const {
73   switch (peer_role_) {
74     case SERVER:
75       return "Server";
76     case WORKER:
77       return "Worker";
78     case SCHEDULER:
79       return "Scheduler";
80     default:
81       return "RoleUndefined";
82   }
83 }
84 
WaitConnected(const uint32_t & connected_timeout)85 bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
86   std::unique_lock<std::mutex> lock(connection_mutex_);
87   bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
88                                        [this] { return this->connection_status_ == 1; });
89   return res;
90 }
91 
Init()92 void TcpClient::Init() {
93   std::lock_guard<std::mutex> lock(connection_mutex_);
94   if (connection_status_ != -1) {
95     return;
96   }
97   connection_status_ = 0;
98   if (buffer_event_) {
99     bufferevent_free(buffer_event_);
100     buffer_event_ = nullptr;
101   }
102   if (!CommUtil::CheckIp(server_address_)) {
103     MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
104   }
105 
106   int result = evthread_use_pthreads();
107   if (result != 0) {
108     MS_LOG(EXCEPTION) << "Use event pthread failed!";
109   }
110   if (event_base_ == nullptr) {
111     event_base_ = event_base_new();
112     MS_EXCEPTION_IF_NULL(event_base_);
113   }
114 
115   if (!PSContext::instance()->enable_ssl()) {
116     MS_LOG(INFO) << "SSL is disable.";
117     buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
118   } else {
119     if (!EstablishSSL()) {
120       MS_LOG(WARNING) << "Establish SSL failed.";
121       return;
122     }
123   }
124 
125   MS_EXCEPTION_IF_NULL(buffer_event_);
126 
127   bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
128   if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
129     MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
130   }
131 
132   if (server_port_ > 0) {
133     sockaddr_in sin{};
134     if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
135       MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
136     }
137     sin.sin_family = AF_INET;
138     sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
139     sin.sin_port = htons(server_port_);
140 
141     int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
142     if (result_code < 0) {
143       MS_LOG(WARNING) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
144     }
145   }
146 }
147 
StartWithDelay(int seconds)148 void TcpClient::StartWithDelay(int seconds) {
149   std::lock_guard<std::mutex> lock(connection_mutex_);
150   if (buffer_event_) {
151     return;
152   }
153 
154   event_base_ = event_base_new();
155   MS_EXCEPTION_IF_NULL(event_base_);
156 
157   timeval timeout_value{};
158   timeout_value.tv_sec = seconds;
159   timeout_value.tv_usec = 0;
160 
161   event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
162   MS_EXCEPTION_IF_NULL(event_timeout_);
163   if (evtimer_add(event_timeout_, &timeout_value) == -1) {
164     MS_LOG(EXCEPTION) << "Event timeout failed!";
165   }
166 }
167 
Stop()168 void TcpClient::Stop() {
169   MS_EXCEPTION_IF_NULL(event_base_);
170   std::lock_guard<std::mutex> lock(connection_mutex_);
171   if (event_base_got_break(event_base_)) {
172     MS_LOG(DEBUG) << "The event base has already been stopped!";
173     return;
174   }
175 
176   MS_LOG(INFO) << "Stop tcp client!";
177   int ret = event_base_loopbreak(event_base_);
178   if (ret != 0) {
179     MS_LOG(ERROR) << "Event base loop break failed!";
180   }
181 }
182 
SetTcpNoDelay(const evutil_socket_t & fd)183 void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
184   const int one = 1;
185   int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
186   if (ret < 0) {
187     MS_LOG(EXCEPTION) << "Set socket no delay failed!";
188   }
189 }
190 
TimeoutCallback(evutil_socket_t,std::int16_t,void * const arg)191 void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *const arg) {
192   try {
193     MS_EXCEPTION_IF_NULL(arg);
194     auto tcp_client = reinterpret_cast<TcpClient *>(arg);
195     tcp_client->Init();
196   } catch (const std::exception &e) {
197     MS_LOG(ERROR) << "Catch exception: " << e.what();
198   }
199 }
200 
ReadCallback(struct bufferevent * bev,void * const ctx)201 void TcpClient::ReadCallback(struct bufferevent *bev, void *const ctx) {
202   try {
203     MS_EXCEPTION_IF_NULL(ctx);
204     auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
205     tcp_client->ReadCallbackInner(bev);
206   } catch (const std::exception &e) {
207     MS_LOG(ERROR) << "Catch exception: " << e.what();
208   }
209 }
210 
ReadCallbackInner(struct bufferevent * bev)211 void TcpClient::ReadCallbackInner(struct bufferevent *bev) {
212   MS_EXCEPTION_IF_NULL(bev);
213 
214   char read_buffer[kMessageChunkLength];
215   size_t read = 0;
216 
217   while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
218     OnReadHandler(read_buffer, read);
219   }
220 }
221 
OnReadHandler(const void * buf,size_t num)222 void TcpClient::OnReadHandler(const void *buf, size_t num) {
223   MS_EXCEPTION_IF_NULL(buf);
224   if (read_callback_) {
225     read_callback_(buf, num);
226   }
227   message_handler_.ReceiveMessage(buf, num);
228 }
229 
TimerCallback(evutil_socket_t,int16_t,void * arg)230 void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
231   MS_EXCEPTION_IF_NULL(arg);
232   auto tcp_client = reinterpret_cast<TcpClient *>(arg);
233   if (tcp_client->on_timer_callback_) {
234     tcp_client->on_timer_callback_();
235   }
236 }
237 
NotifyConnected()238 void TcpClient::NotifyConnected() {
239   MS_LOG(INFO) << "Client connected to the server! Peer " << PeerRoleName() << " ip: " << server_address_
240                << ", port: " << server_port_;
241   connection_status_ = 1;
242   connection_cond_.notify_all();
243 }
244 
EstablishSSL()245 bool TcpClient::EstablishSSL() {
246   MS_LOG(INFO) << "Enable ssl support.";
247 
248   SSL *ssl = SSL_new(SSLClient::GetInstance().GetSSLCtx());
249   MS_ERROR_IF_NULL_W_RET_VAL(ssl, false);
250   MS_ERROR_IF_NULL_W_RET_VAL(event_base_, false);
251 
252   buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
253                                                  BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
254   return true;
255 }
256 
EventCallback(struct bufferevent * bev,std::int16_t events,void * const ptr)257 void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *const ptr) {
258   try {
259     MS_EXCEPTION_IF_NULL(ptr);
260     auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
261     tcp_client->EventCallbackInner(bev, events);
262   } catch (const std::exception &e) {
263     MS_LOG(ERROR) << "Catch exception: " << e.what();
264   }
265 }
266 
EventCallbackInner(struct bufferevent * bev,std::int16_t events)267 void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events) {
268   MS_EXCEPTION_IF_NULL(bev);
269   if (events & BEV_EVENT_CONNECTED) {
270     // Connected
271     if (connected_callback_) {
272       connected_callback_();
273     }
274     NotifyConnected();
275     evutil_socket_t fd = bufferevent_getfd(bev);
276     SetTcpNoDelay(fd);
277     MS_LOG(INFO) << "Client connected! Peer " << PeerRoleName() << " ip: " << server_address_
278                  << ", port: " << server_port_;
279   } else if (events & BEV_EVENT_ERROR) {
280     if (PSContext::instance()->enable_ssl()) {
281       uint64_t err = bufferevent_get_openssl_error(bev);
282       const uint64_t server_not_start_err = 5;
283       if (err != server_not_start_err) {
284         MS_LOG(WARNING) << "The error number is:" << err << ", error message:" << ERR_reason_error_string(err)
285                         << ", the error lib:" << ERR_lib_error_string(err)
286                         << ", the error func:" << ERR_func_error_string(err);
287         return;
288       }
289     }
290     connection_status_ = -1;
291     if (disconnected_callback_) {
292       MS_LOG(WARNING) << "The client will retry to connect to the server! Peer " << PeerRoleName()
293                       << " ip: " << server_address_ << ", port: " << server_port_;
294       disconnected_callback_();
295     }
296   } else if (events & BEV_EVENT_EOF) {
297     MS_LOG(WARNING) << "Client connected end of file! Peer " << PeerRoleName() << " ip: " << server_address_
298                     << ", port: " << server_port_;
299     connection_status_ = -1;
300     if (disconnected_callback_) {
301       disconnected_callback_();
302     }
303   }
304 }
305 
Start()306 void TcpClient::Start() {
307   event_base_mutex_.lock();
308   if (is_started_) {
309     event_base_mutex_.unlock();
310     return;
311   }
312   is_started_ = true;
313   event_base_mutex_.unlock();
314   MS_EXCEPTION_IF_NULL(event_base_);
315   int ret = event_base_dispatch(event_base_);
316   // is_started_ should be false when finish dispatch
317   is_started_ = false;
318   MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base dispatch success!";
319   MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
320     << "Event base dispatch failed with no events pending or active!";
321   MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr)
322     << "Event base dispatch failed with error occurred!";
323   MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr)
324     << "Event base dispatch with unexpected error code!";
325 }
326 
StartWithNoBlock()327 void TcpClient::StartWithNoBlock() {
328   std::lock_guard<std::mutex> lock(connection_mutex_);
329   MS_LOG(INFO) << "Start tcp client with no block!";
330   MS_EXCEPTION_IF_NULL(event_base_);
331   int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
332   MSLOG_IF(MsLogLevel::kInfo, ret == 0, NoExceptionType, nullptr) << "Event base loop success!";
333   MSLOG_IF(MsLogLevel::kError, ret == 1, NoExceptionType, nullptr)
334     << "Event base loop failed with no events pending or active!";
335   MSLOG_IF(MsLogLevel::kError, ret == -1, NoExceptionType, nullptr) << "Event base loop failed with error occurred!";
336   MSLOG_IF(MsLogLevel::kException, ret < -1, AbortedError, nullptr) << "Event base loop with unexpected error code!";
337 }
338 
SetMessageCallback(const OnMessage & cb)339 void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
340 
SendMessage(const CommMessage & message) const341 bool TcpClient::SendMessage(const CommMessage &message) const {
342   MS_EXCEPTION_IF_NULL(buffer_event_);
343   bufferevent_lock(buffer_event_);
344   bool res = true;
345   size_t buf_size = message.ByteSizeLong();
346   uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
347   MessageHeader header;
348   header.message_proto_ = Protos::PROTOBUF;
349   header.message_length_ = buf_size;
350   header.message_meta_length_ = meta_size;
351   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
352     MS_LOG(ERROR) << "Event buffer add header failed!";
353     res = false;
354   }
355   if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
356     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
357     res = false;
358   }
359   if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
360     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
361     res = false;
362   }
363   bufferevent_unlock(buffer_event_);
364   return res;
365 }
366 
SendMessage(const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)367 bool TcpClient::SendMessage(const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
368                             size_t size) {
369   MS_EXCEPTION_IF_NULL(buffer_event_);
370   MS_EXCEPTION_IF_NULL(meta);
371   MS_EXCEPTION_IF_NULL(data);
372   bufferevent_lock(buffer_event_);
373   bool res = true;
374 
375   MessageHeader header;
376   header.message_proto_ = protos;
377   header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
378   header.message_length_ = size + header.message_meta_length_;
379 
380   if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
381     MS_LOG(ERROR) << "Event buffer add header failed!";
382     res = false;
383   }
384   if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
385     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
386     res = false;
387   }
388   if (bufferevent_write(buffer_event_, data, size) == -1) {
389     MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
390     res = false;
391   }
392   int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
393   if (result < 0) {
394     MS_LOG(ERROR) << "Bufferevent flush failed!";
395     res = false;
396   }
397   bufferevent_unlock(buffer_event_);
398   return res;
399 }
400 
set_timer_callback(const OnTimer & timer)401 void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }
402 
eventbase() const403 const event_base &TcpClient::eventbase() const { return *event_base_; }
404 }  // namespace core
405 }  // namespace ps
406 }  // namespace mindspore
407