1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/backend/distributed/rpc/tcp/tcp_client.h"

#include <exception>
#include <string>

#include "ps/core/communicator/ssl_client.h"
#include "distributed/rpc/tcp/tcp_comm.h"
20
21 namespace mindspore {
22 namespace distributed {
23 namespace rpc {
TCPClient(bool enable_ssl)24 TCPClient::TCPClient(bool enable_ssl) : RPCClientBase(enable_ssl), tcp_comm_(nullptr), received_message_(nullptr) {
25 std::string env_receive_msg_timeout = common::GetEnv(kEnvReceiveMsgTimeOut);
26 int int_receive_timeout =
27 env_receive_msg_timeout.empty() ? kDefaultReceiveMsgTimeOut : std::stoi(env_receive_msg_timeout);
28 receive_timeout_ = (int_receive_timeout < 0) ? UINT64_MAX : int_receive_timeout;
29 MS_LOG(INFO) << "Tcp client receiving message timeout is " << receive_timeout_ << " seconds.";
30 }
~TCPClient()31 TCPClient::~TCPClient() {}
32
Initialize()33 bool TCPClient::Initialize() {
34 bool rt = false;
35 if (tcp_comm_ == nullptr) {
36 if (enable_ssl_) {
37 (void)ps::core::SSLClient::GetInstance().GetSSLCtx();
38 }
39 tcp_comm_ = std::make_unique<TCPComm>(enable_ssl_);
40 MS_EXCEPTION_IF_NULL(tcp_comm_);
41
42 // This message handler is used to accept and maintain the received message from the tcp server.
43 tcp_comm_->SetMessageHandler([this](MessageBase *const message) -> MessageBase *const {
44 // Wait for the previous received message has been handled.
45 const int sleep_time = 10;
46 while (received_message_ != nullptr) {
47 std::this_thread::sleep_for(std::chrono::milliseconds(sleep_time));
48 }
49 std::unique_lock<std::mutex> lock(mutex_);
50 received_message_ = message;
51 wait_msg_cond_.notify_one();
52 return NULL_MSG;
53 });
54 rt = tcp_comm_->Initialize();
55 } else {
56 rt = true;
57 }
58 return rt;
59 }
60
Finalize()61 void TCPClient::Finalize() {
62 if (tcp_comm_ != nullptr) {
63 tcp_comm_->Finalize();
64 tcp_comm_.reset();
65 tcp_comm_ = nullptr;
66 }
67 }
68
Connect(const std::string & dst_url,size_t retry_count,const MemFreeCallback & free_cb)69 bool TCPClient::Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) {
70 unsigned int interval = 2;
71 for (size_t i = 0; i < retry_count; ++i) {
72 if (tcp_comm_->Connect(dst_url, free_cb)) {
73 MS_LOG(INFO) << "Connected to the tcp server " << dst_url << " successfully.";
74 return true;
75 } else {
76 MS_LOG(WARNING) << "Failed to connect to the tcp server : " << dst_url << ", retry to reconnect(" << (i + 1)
77 << "/" << retry_count << ")...";
78 if (!tcp_comm_->Disconnect(dst_url)) {
79 MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
80 return false;
81 }
82 (void)sleep(interval);
83 }
84 }
85 return false;
86 }
87
IsConnected(const std::string & dst_url)88 bool TCPClient::IsConnected(const std::string &dst_url) { return tcp_comm_->IsConnected(dst_url); }
89
Disconnect(const std::string & dst_url,size_t timeout_in_sec)90 bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
91 bool rt = false;
92 if (!tcp_comm_->Disconnect(dst_url)) {
93 MS_LOG(ERROR) << "Can not disconnect from the server: " << dst_url;
94 return false;
95 }
96
97 size_t timeout_in_ms = timeout_in_sec * 1000;
98 size_t sleep_in_ms = 100;
99 useconds_t sleep_in_us = 100000;
100
101 while (true) {
102 if (!tcp_comm_->IsConnected(dst_url)) {
103 rt = true;
104 break;
105 }
106 if (timeout_in_ms > sleep_in_ms) {
107 timeout_in_ms -= sleep_in_ms;
108 } else {
109 break;
110 }
111 (void)usleep(sleep_in_us);
112 }
113 return rt;
114 }
115
SendSync(std::unique_ptr<MessageBase> && msg,size_t * const send_bytes)116 bool TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
117 return tcp_comm_->Send(msg.release(), send_bytes, true);
118 }
119
SendAsync(std::unique_ptr<MessageBase> && msg)120 void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), nullptr, false); }
121
ReceiveSync(std::unique_ptr<MessageBase> && msg,uint32_t timeout)122 MessageBase *TCPClient::ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout) {
123 if (timeout == UINT32_MAX) {
124 // This means we should use default ReceiveMsgTimeOut as timeout.
125 timeout = receive_timeout_;
126 }
127 bool retval = tcp_comm_->Send(msg.release(), nullptr, true);
128 if (retval) {
129 std::unique_lock<std::mutex> lock(mutex_);
130 received_message_ = nullptr;
131 bool res =
132 wait_msg_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] { return received_message_ != nullptr; });
133 if (res) {
134 // Clear the address of received message before returning this address to the caller, because the next
135 // `ReceiveSync` call will block on the received message's condition variable.
136 MessageBase *message = received_message_;
137 return message;
138 } else {
139 MS_LOG(WARNING) << "Failed to receive message.";
140 }
141 } else {
142 MS_LOG(INFO) << "Failed to send message in ReceiveSync.";
143 }
144 return NULL_MSG;
145 }
146
Flush(const std::string & dst_url)147 bool TCPClient::Flush(const std::string &dst_url) { return tcp_comm_->Flush(dst_url); }
148
GetClientIPByDstUrl(const std::string & dst_url) const149 std::string TCPClient::GetClientIPByDstUrl(const std::string &dst_url) const {
150 return tcp_comm_->GetClientSrcIP(dst_url);
151 }
152 } // namespace rpc
153 } // namespace distributed
154 } // namespace mindspore
155