• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_
19 
20 #include <map>
21 #include <string>
22 #include <memory>
23 #include <mutex>
24 
25 #include "actor/msg.h"
26 #include "distributed/rpc/tcp/connection.h"
27 #include "distributed/rpc/tcp/connection_pool.h"
28 #include "distributed/rpc/tcp/event_loop.h"
29 
30 namespace mindspore {
31 namespace distributed {
32 namespace rpc {
33 // Event handler for new connecting request arrived.
34 void OnAccept(int server, uint32_t events, void *arg);
35 
36 void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError);
37 
38 void ConnectedEventHandler(int fd, uint32_t events, void *context);
39 
40 class TCPComm {
41  public:
42   explicit TCPComm(bool enable_ssl = false)
43       : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr), enable_ssl_(enable_ssl) {}
44   TCPComm(const TCPComm &) = delete;
45   TCPComm &operator=(const TCPComm &) = delete;
46   ~TCPComm() = default;
47 
48   // Init the event loop for reading and writing.
49   bool Initialize();
50 
51   // Destroy all the resources.
52   void Finalize();
53 
54   // Create the server socket represented by url.
55   // allocate_cb is the method used to allocate memory when server receiving message from the remote.
56   int StartServerSocket(const std::string &url, const MemAllocateCallback &allocate_cb);
57 
58   // Create the server socket with local IP and random port.
59   int StartServerSocket(const MemAllocateCallback &allocate_cb);
60 
61   // Connection operation for a specified destination.
62   bool Connect(const std::string &dst_url, const MemFreeCallback &free_cb);
63   bool IsConnected(const std::string &dst_url);
64   bool Disconnect(const std::string &dst_url);
65 
66   // Send the message from the source to the destination.
67   // The flag sync means if the message is sent directly or added to the task queue.
68   bool Send(MessageBase *msg, size_t *const send_bytes, bool sync = false);
69 
70   // Force the data in the send buffer to be sent out.
71   bool Flush(const std::string &dst_url);
72 
73   // Set the message processing handler.
74   void SetMessageHandler(const MessageHandler &handler);
75 
76   // Get the file descriptor of server socket.
77   int GetServerFd() const;
78 
GetClientSrcIP(const std::string & dst_url)79   const std::string &GetClientSrcIP(const std::string &dst_url) { return dst_url_to_src_ip_[dst_url]; }
80 
81   /**
82    * @description: Returns the allocating callback.
83    * @return {const MemAllocateCallback &}
84    */
allocate_cb()85   const MemAllocateCallback &allocate_cb() const { return allocate_cb_; }
86 
87  private:
88   // Build the connection.
89   Connection *CreateDefaultConn(const std::string &to);
90 
91   // Send a message.
92   static void SendExitMsg(const std::string &from, const std::string &to);
93 
94   // Called by ReadCallBack when new message arrived.
95   static int ReceiveMessage(Connection *conn);
96 
97   static int SetConnectedHandler(Connection *conn);
98 
99   static int DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen);
100 
101   static void DropMessage(MessageBase *msg);
102 
103   // Read and write events.
104   void ReadCallBack(void *conn);
105   void WriteCallBack(void *conn);
106   // Connected and Disconnected events.
107   void EventCallBack(void *conn);
108 
109   // The server url.
110   std::string url_;
111 
112   // The socket of server.
113   int server_fd_;
114 
115   // User defined handler for Handling received messages.
116   MessageHandler message_handler_;
117 
118   // All the connections share the same read and write event loop objects.
119   EventLoop *recv_event_loop_;
120   EventLoop *send_event_loop_;
121 
122   // The connection pool used to store new connections.
123   std::shared_ptr<ConnectionPool> conn_pool_;
124 
125   // The mutex for connection operations.
126   std::shared_ptr<std::mutex> conn_mutex_;
127 
128   // The method used to allocate memory when tcp servers of this TcpComm receive message from the remote.
129   MemAllocateCallback allocate_cb_;
130 
131   bool enable_ssl_;
132 
133   friend void OnAccept(int server, uint32_t events, void *arg);
134   friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback,
135                        ConnectionCallBack write_callback, ConnectionCallBack read_callback);
136 
137   // The map from dst_url to src_ip which this tcp client uses.
138   std::map<std::string, std::string> dst_url_to_src_ip_;
139 };
140 }  // namespace rpc
141 }  // namespace distributed
142 }  // namespace mindspore
143 
144 #endif
145