1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONNECTION_H_ 19 20 #include <queue> 21 #include <string> 22 #include <mutex> 23 #include <memory> 24 25 #include "actor/msg.h" 26 #include "include/backend/distributed/rpc/tcp/constants.h" 27 #include "distributed/rpc/tcp/event_loop.h" 28 #include "distributed/rpc/tcp/socket_operation.h" 29 30 namespace mindspore { 31 namespace distributed { 32 namespace rpc { 33 /* 34 * The SendMetrics is responsible for collecting metrics when sending data through a connection. 35 */ 36 struct SendMetrics { 37 // Records the message number and max body size. UpdateMaxSendMetrics38 void UpdateMax(size_t size) { 39 accum_msg_count++; 40 if (size > max_msg_size) { 41 max_msg_size = size; 42 } 43 } 44 45 // Records the latest error message. 46 void UpdateError(bool fail, int err = 0) { 47 if (fail) { 48 last_fail_msg_name = last_send_msg_name; 49 error_code = err; 50 } else { 51 last_succ_msg_name = last_send_msg_name; 52 } 53 } 54 55 // Reset all the metrics info. ResetSendMetrics56 void Reset() { 57 accum_msg_count = 0; 58 max_msg_size = 0; 59 error_code = 0; 60 last_succ_msg_name = ""; 61 last_fail_msg_name = ""; 62 last_send_msg_name = ""; 63 } 64 65 // The total number of bytes sent already. 66 size_t accum_msg_count{0}; 67 68 // The max message body size sent in bytes. 69 size_t max_msg_size{0}; 70 int error_code{0}; 71 72 std::string last_succ_msg_name; 73 std::string last_fail_msg_name; 74 std::string last_send_msg_name; 75 }; 76 77 /* 78 * Represents a TCP or SSL connection. 79 */ 80 struct Connection { 81 public: 82 Connection(); 83 ~Connection() = default; 84 85 // Initialize the connection(eg. add some socket event handlers). 86 int Initialize(); 87 88 // Create a new socket operation if needed. 89 void InitSocketOperation(); 90 91 // Delete this socket fd(source client socket) and add back to the connection. 92 bool ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error); 93 94 // Disconnect the socket fd from source. 95 void Disconnect(int fd); 96 97 // Close this connection. 98 void Close(); 99 100 int ReceiveMessage(); 101 void CheckMessageType(); 102 103 // Fill the message to be sent based on the input message. 104 void FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg); 105 106 void FillRecvMessage(); 107 IsSameConnection108 bool IsSame(const Connection *that) { 109 return !(that != nullptr && that->destination == destination && that->is_remote == is_remote); 110 } 111 112 // Send all the messages in the message queue. 113 size_t Flush(); 114 115 /** 116 * @description: Set callback to allocate memory for this connection when receiving message from the remote. 117 * @param {MemAllocateCallback} &allocate_cb: The allocating memory callback. 118 * @return {void} 119 */ SetAllocateCallbackConnection120 void SetAllocateCallback(const MemAllocateCallback &allocate_cb) { allocate_cb_ = allocate_cb; } 121 122 /** 123 * @description: Set callback to free message for this connection. 124 * @param {MemFreeCallback} &free_cb: The callback which frees the real memory after message is sent to peer. 125 * @return {void} 126 */ SetMessageFreeCallbackConnection127 void SetMessageFreeCallback(const MemFreeCallback &free_cb) { free_cb_ = free_cb; } 128 129 /** 130 * @description: Returns the free callback. 131 * @return {const MemFreeCallback &} 132 */ free_cbConnection133 const MemFreeCallback &free_cb() const { return free_cb_; } 134 135 /** 136 * @description: Free the real data of the message using free callback. 137 * @param {MessageBase} *msg: The MessageBase object. 138 * @return {bool}: Whether successfully freeing the real data. 139 */ 140 bool FreeMessageMemory(MessageBase *msg); 141 142 // The socket used by this connection. 143 int socket_fd; 144 145 // Indicates whether this connection is deleted from link manager. 146 bool deleted; 147 148 // Indicates the priority of this connection. 149 ConnectionPriority priority{ConnectionPriority::kPriorityHigh}; 150 151 // Indicates whether this connection is connected from remote client. 152 // A connection is remote only when the connection is created by the `OnAccept` callback. 153 bool is_remote; 154 155 // TCP or SSL. 156 ConnectionType type; 157 158 // The socket address(ip:port) of client and server of this connection. 159 std::string source; 160 std::string destination; 161 162 // Peer address. 163 std::string peer; 164 165 // Specific operations for the socket in this connection. 166 SocketOperation *socket_operation; 167 168 bool enable_ssl{false}; 169 170 // The state of this connection(eg. kInit/kConnecting/..) 171 ConnectionState state{kInit}; 172 173 // The threads for handling the receive and send requsets on this connection. 174 EventLoop *send_event_loop; 175 EventLoop *recv_event_loop; 176 177 // Collects data sending metrics. 178 SendMetrics *send_metrics; 179 180 // The message data waiting to be sent and receive through this connection.. 181 MessageBase *send_message; 182 MessageBase *recv_message; 183 184 // Owned by the tcp_comm. 185 std::shared_ptr<std::mutex> conn_mutex; 186 187 // Owned by connection itself. 188 std::mutex conn_owned_mutex_; 189 190 State recv_state; 191 192 // Total length of received and sent messages. 193 size_t total_recv_len; 194 size_t total_send_len; 195 size_t recv_len; 196 197 std::string send_to; 198 std::string send_from; 199 std::string recv_to; 200 std::string recv_from; 201 202 // Message header. 203 MessageHeader send_msg_header; 204 MessageHeader recv_msg_header; 205 206 // The message structure of kernel. 207 struct msghdr send_kernel_msg; 208 struct msghdr recv_kernel_msg; 209 210 struct iovec recv_io_vec[RECV_MSG_IO_VEC_LEN]; 211 struct iovec send_io_vec[SEND_MSG_IO_VEC_LEN]; 212 213 ParseType recv_message_type{kTcpMsg}; 214 215 // Callbacks for io events 216 ConnectionCallBack event_callback; 217 ConnectionCallBack succ_callback; 218 ConnectionCallBack write_callback; 219 ConnectionCallBack read_callback; 220 221 // Function for handling received messages. 222 MessageHandler message_handler; 223 224 // Buffer for messages to be sent. 225 std::queue<MessageBase *> send_message_queue; 226 227 uint64_t output_buffer_size; 228 229 // The error code when sending or receiving messages. 230 int error_code; 231 232 // The method used to allocate memory when server receiving message from the remote. 233 MemAllocateCallback allocate_cb_; 234 235 // The method used to free the memory after client sending data to the remote. 236 MemFreeCallback free_cb_; 237 238 private: 239 // Add handler for socket connect event. 240 int AddConnnectEventHandler(); 241 242 // Parse message from socket recv buffer. 243 bool ParseMessage(); 244 245 // After ParseMessage, set from url and to url into recv message. 246 bool SetUrlForRecvMessage(); 247 248 // Make a http message based on given input message. 249 std::string GenerateHttpMessage(MessageBase *msg); 250 251 // Change the header body from network byte order to host byte order. 252 void ReorderHeader(MessageHeader *header) const; 253 254 /** 255 * @description: Get the real data pointer of the message. 256 * @param {MessageBase} *msg: The MessageBase object. 257 * @return {void *}: The pointer to the memory of the real data. 258 */ 259 void *GetMessageBaseRealData(const MessageBase *msg) const; 260 261 /** 262 * @description: Get size of the real data size. 263 * @param {MessageBase} *msg: The MessageBase object. 264 * @return {size_t}: The size of the real data. 265 */ 266 size_t GetMessageBaseRealDataSize(const MessageBase *msg) const; 267 268 std::string advertise_addr_; 269 }; 270 } // namespace rpc 271 } // namespace distributed 272 } // namespace mindspore 273 274 #endif 275