1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_TCP_COMM_H_ 19 20 #include <map> 21 #include <string> 22 #include <memory> 23 #include <mutex> 24 25 #include "actor/msg.h" 26 #include "distributed/rpc/tcp/connection.h" 27 #include "distributed/rpc/tcp/connection_pool.h" 28 #include "distributed/rpc/tcp/event_loop.h" 29 30 namespace mindspore { 31 namespace distributed { 32 namespace rpc { 33 // Event handler for new connecting request arrived. 34 void OnAccept(int server, uint32_t events, void *arg); 35 36 void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError); 37 38 void ConnectedEventHandler(int fd, uint32_t events, void *context); 39 40 class TCPComm { 41 public: 42 explicit TCPComm(bool enable_ssl = false) 43 : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr), enable_ssl_(enable_ssl) {} 44 TCPComm(const TCPComm &) = delete; 45 TCPComm &operator=(const TCPComm &) = delete; 46 ~TCPComm() = default; 47 48 // Init the event loop for reading and writing. 49 bool Initialize(); 50 51 // Destroy all the resources. 52 void Finalize(); 53 54 // Create the server socket represented by url. 55 // allocate_cb is the method used to allocate memory when server receiving message from the remote. 56 int StartServerSocket(const std::string &url, const MemAllocateCallback &allocate_cb); 57 58 // Create the server socket with local IP and random port. 59 int StartServerSocket(const MemAllocateCallback &allocate_cb); 60 61 // Connection operation for a specified destination. 62 bool Connect(const std::string &dst_url, const MemFreeCallback &free_cb); 63 bool IsConnected(const std::string &dst_url); 64 bool Disconnect(const std::string &dst_url); 65 66 // Send the message from the source to the destination. 67 // The flag sync means if the message is sent directly or added to the task queue. 68 bool Send(MessageBase *msg, size_t *const send_bytes, bool sync = false); 69 70 // Force the data in the send buffer to be sent out. 71 bool Flush(const std::string &dst_url); 72 73 // Set the message processing handler. 74 void SetMessageHandler(const MessageHandler &handler); 75 76 // Get the file descriptor of server socket. 77 int GetServerFd() const; 78 GetClientSrcIP(const std::string & dst_url)79 const std::string &GetClientSrcIP(const std::string &dst_url) { return dst_url_to_src_ip_[dst_url]; } 80 81 /** 82 * @description: Returns the allocating callback. 83 * @return {const MemAllocateCallback &} 84 */ allocate_cb()85 const MemAllocateCallback &allocate_cb() const { return allocate_cb_; } 86 87 private: 88 // Build the connection. 89 Connection *CreateDefaultConn(const std::string &to); 90 91 // Send a message. 92 static void SendExitMsg(const std::string &from, const std::string &to); 93 94 // Called by ReadCallBack when new message arrived. 95 static int ReceiveMessage(Connection *conn); 96 97 static int SetConnectedHandler(Connection *conn); 98 99 static int DoConnect(Connection *conn, const struct sockaddr *sa, socklen_t saLen); 100 101 static void DropMessage(MessageBase *msg); 102 103 // Read and write events. 104 void ReadCallBack(void *conn); 105 void WriteCallBack(void *conn); 106 // Connected and Disconnected events. 107 void EventCallBack(void *conn); 108 109 // The server url. 110 std::string url_; 111 112 // The socket of server. 113 int server_fd_; 114 115 // User defined handler for Handling received messages. 116 MessageHandler message_handler_; 117 118 // All the connections share the same read and write event loop objects. 119 EventLoop *recv_event_loop_; 120 EventLoop *send_event_loop_; 121 122 // The connection pool used to store new connections. 123 std::shared_ptr<ConnectionPool> conn_pool_; 124 125 // The mutex for connection operations. 126 std::shared_ptr<std::mutex> conn_mutex_; 127 128 // The method used to allocate memory when tcp servers of this TcpComm receive message from the remote. 129 MemAllocateCallback allocate_cb_; 130 131 bool enable_ssl_; 132 133 friend void OnAccept(int server, uint32_t events, void *arg); 134 friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback, 135 ConnectionCallBack write_callback, ConnectionCallBack read_callback); 136 137 // The map from dst_url to src_ip which this tcp client uses. 138 std::map<std::string, std::string> dst_url_to_src_ip_; 139 }; 140 } // namespace rpc 141 } // namespace distributed 142 } // namespace mindspore 143 144 #endif 145