• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "include/backend/distributed/rpc/tcp/tcp_client.h"

#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>

#include "ps/core/communicator/ssl_client.h"
#include "distributed/rpc/tcp/tcp_comm.h"
20 
21 namespace mindspore {
22 namespace distributed {
23 namespace rpc {
TCPClient(bool enable_ssl)24 TCPClient::TCPClient(bool enable_ssl) : RPCClientBase(enable_ssl), tcp_comm_(nullptr), received_message_(nullptr) {
25   std::string env_receive_msg_timeout = common::GetEnv(kEnvReceiveMsgTimeOut);
26   int int_receive_timeout =
27     env_receive_msg_timeout.empty() ? kDefaultReceiveMsgTimeOut : std::stoi(env_receive_msg_timeout);
28   receive_timeout_ = (int_receive_timeout < 0) ? UINT64_MAX : int_receive_timeout;
29   MS_LOG(INFO) << "Tcp client receiving message timeout is " << receive_timeout_ << " seconds.";
30 }
~TCPClient()31 TCPClient::~TCPClient() {}
32 
Initialize()33 bool TCPClient::Initialize() {
34   bool rt = false;
35   if (tcp_comm_ == nullptr) {
36     if (enable_ssl_) {
37       (void)ps::core::SSLClient::GetInstance().GetSSLCtx();
38     }
39     tcp_comm_ = std::make_unique<TCPComm>(enable_ssl_);
40     MS_EXCEPTION_IF_NULL(tcp_comm_);
41 
42     // This message handler is used to accept and maintain the received message from the tcp server.
43     tcp_comm_->SetMessageHandler([this](MessageBase *const message) -> MessageBase *const {
44       // Wait for the previous received message has been handled.
45       const int sleep_time = 10;
46       while (received_message_ != nullptr) {
47         std::this_thread::sleep_for(std::chrono::milliseconds(sleep_time));
48       }
49       std::unique_lock<std::mutex> lock(mutex_);
50       received_message_ = message;
51       wait_msg_cond_.notify_one();
52       return NULL_MSG;
53     });
54     rt = tcp_comm_->Initialize();
55   } else {
56     rt = true;
57   }
58   return rt;
59 }
60 
Finalize()61 void TCPClient::Finalize() {
62   if (tcp_comm_ != nullptr) {
63     tcp_comm_->Finalize();
64     tcp_comm_.reset();
65     tcp_comm_ = nullptr;
66   }
67 }
68 
Connect(const std::string & dst_url,size_t retry_count,const MemFreeCallback & free_cb)69 bool TCPClient::Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) {
70   unsigned int interval = 2;
71   for (size_t i = 0; i < retry_count; ++i) {
72     if (tcp_comm_->Connect(dst_url, free_cb)) {
73       MS_LOG(INFO) << "Connected to the tcp server " << dst_url << " successfully.";
74       return true;
75     } else {
76       MS_LOG(WARNING) << "Failed to connect to the tcp server : " << dst_url << ", retry to reconnect(" << (i + 1)
77                       << "/" << retry_count << ")...";
78       if (!tcp_comm_->Disconnect(dst_url)) {
79         MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
80         return false;
81       }
82       (void)sleep(interval);
83     }
84   }
85   return false;
86 }
87 
IsConnected(const std::string & dst_url)88 bool TCPClient::IsConnected(const std::string &dst_url) { return tcp_comm_->IsConnected(dst_url); }
89 
Disconnect(const std::string & dst_url,size_t timeout_in_sec)90 bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
91   bool rt = false;
92   if (!tcp_comm_->Disconnect(dst_url)) {
93     MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
94     return false;
95   }
96 
97   size_t timeout_in_ms = timeout_in_sec * 1000;
98   size_t sleep_in_ms = 100;
99   useconds_t sleep_in_us = 100000;
100 
101   while (true) {
102     if (!tcp_comm_->IsConnected(dst_url)) {
103       rt = true;
104       break;
105     }
106     if (timeout_in_ms > sleep_in_ms) {
107       timeout_in_ms -= sleep_in_ms;
108     } else {
109       break;
110     }
111     (void)usleep(sleep_in_us);
112   }
113   return rt;
114 }
115 
SendSync(std::unique_ptr<MessageBase> && msg,size_t * const send_bytes)116 bool TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
117   return tcp_comm_->Send(msg.release(), send_bytes, true);
118 }
119 
SendAsync(std::unique_ptr<MessageBase> && msg)120 void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), nullptr, false); }
121 
ReceiveSync(std::unique_ptr<MessageBase> && msg,uint32_t timeout)122 MessageBase *TCPClient::ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout) {
123   if (timeout == UINT32_MAX) {
124     // This means we should use default ReceiveMsgTimeOut as timeout.
125     timeout = receive_timeout_;
126   }
127   bool retval = tcp_comm_->Send(msg.release(), nullptr, true);
128   if (retval) {
129     std::unique_lock<std::mutex> lock(mutex_);
130     received_message_ = nullptr;
131     bool res =
132       wait_msg_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] { return received_message_ != nullptr; });
133     if (res) {
134       // Clear the address of received message before returning this address to the caller, because the next
135       // `ReceiveSync` call will block on the received message's condition variable.
136       MessageBase *message = received_message_;
137       return message;
138     } else {
139       MS_LOG(WARNING) << "Failed to receive message.";
140     }
141   } else {
142     MS_LOG(INFO) << "Failed to send message in ReceiveSync.";
143   }
144   return NULL_MSG;
145 }
146 
Flush(const std::string & dst_url)147 bool TCPClient::Flush(const std::string &dst_url) { return tcp_comm_->Flush(dst_url); }
148 
GetClientIPByDstUrl(const std::string & dst_url) const149 std::string TCPClient::GetClientIPByDstUrl(const std::string &dst_url) const {
150   return tcp_comm_->GetClientSrcIP(dst_url);
151 }
152 }  // namespace rpc
153 }  // namespace distributed
154 }  // namespace mindspore
155