1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ 19 20 #include <string> 21 #include <memory> 22 #include <mutex> 23 #include <condition_variable> 24 25 #include "include/backend/distributed/rpc/rdma/constants.h" 26 #include "include/backend/distributed/rpc/rpc_client_base.h" 27 28 namespace mindspore { 29 namespace distributed { 30 namespace rpc { 31 class BACKEND_EXPORT RDMAClient : public RPCClientBase { 32 public: 33 explicit RDMAClient(bool enable_ssl = false) RPCClientBase(enable_ssl)34 : RPCClientBase(enable_ssl), 35 dev_name_(kDefaultIfName), 36 ip_addr_(kDefaultIP), 37 port_(kDefaultPort), 38 func_id_(0), 39 urpc_allocator_(urpc_get_default_allocator_func()), 40 urpc_session_(nullptr) {} 41 ~RDMAClient() override = default; 42 43 bool Initialize() override; 44 void Finalize() override; 45 bool Connect( 46 const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) { 47 MS_ERROR_IF_NULL(data); 48 delete static_cast<char *>(data); 49 return true; 50 }) override; 51 bool IsConnected(const std::string &dst_url) override; 52 bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override; 53 54 bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) override; 55 void SendAsync(std::unique_ptr<MessageBase> &&msg) override; 56 57 bool Flush(const std::string &dst_url) override; 58 59 // The callback after server responding. 60 static void urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg); 61 62 private: 63 std::string dev_name_; 64 std::string ip_addr_; 65 uint16_t port_; 66 uint32_t func_id_; 67 68 struct urpc_buffer_allocator *urpc_allocator_; 69 urpc_session_t *urpc_session_; 70 71 // The variables for synchronization of async messages. 72 std::mutex mtx_; 73 std::condition_variable cv_; 74 75 // Callback arguments when request is successfully received by peer. 76 // It's used in async scenario to do releasing and synchronizing operations. 77 struct req_cb_arg cb_arg_; 78 }; 79 } // namespace rpc 80 } // namespace distributed 81 } // namespace mindspore 82 83 #endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ 84