• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_CLIENT_H_
19 
20 #include <string>
21 #include <memory>
22 #include <mutex>
23 #include <condition_variable>
24 
25 #include "include/backend/distributed/rpc/rpc_client_base.h"
26 #include "utils/ms_utils.h"
27 #include "include/backend/visible.h"
28 
29 namespace mindspore {
30 namespace distributed {
31 namespace rpc {
32 class TCPComm;
33 
34 class BACKEND_EXPORT TCPClient : public RPCClientBase {
35  public:
36   explicit TCPClient(bool enable_ssl = false);
37   ~TCPClient() override;
38 
39   // Build or destroy the TCP client.
40   bool Initialize() override;
41   void Finalize() override;
42 
43   // Connect to the specified server.
44   // Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
45   bool Connect(
46     const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) {
47       MS_ERROR_IF_NULL(data);
48       delete static_cast<char *>(data);
49       return true;
50     }) override;
51 
52   // Check if the connection to dst_url has been established.
53   bool IsConnected(const std::string &dst_url) override;
54 
55   // Disconnect from the specified server.
56   bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override;
57 
58   // Send the message from the source to the destination synchronously and return the byte size by this method call.
59   bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) override;
60 
61   // Send the message from the source to the destination asynchronously.
62   void SendAsync(std::unique_ptr<MessageBase> &&msg) override;
63 
64   // Retrieve a message from tcp server specified by the input message.
65   // Returns nullptr after timeout.
66   MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = UINT32_MAX) override;
67 
68   // Force the data in the send buffer to be sent out.
69   bool Flush(const std::string &dst_url) override;
70 
71   std::string GetClientIPByDstUrl(const std::string &dst_url) const;
72 
73  private:
74   // The basic TCP communication component used by the client.
75   std::unique_ptr<TCPComm> tcp_comm_;
76 
77   // The mutex and condition variable used to synchronize the write and read of the received message returned by calling
78   // the `ReceiveSync` method.
79   std::mutex mutex_;
80   std::condition_variable wait_msg_cond_;
81 
82   // The received message from the meta server by calling the method `ReceiveSync`.
83   MessageBase *received_message_;
84 
85   // The timeout(second) window for receiving message from other nodes.
86   uint32_t receive_timeout_;
87 
88   DISABLE_COPY_AND_ASSIGN(TCPClient);
89 };
90 }  // namespace rpc
91 }  // namespace distributed
92 }  // namespace mindspore
93 
94 #endif
95