1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/distributed/rpc/rdma/rdma_server.h"
18
19 #include <string>
20
21 namespace mindspore {
22 namespace distributed {
23 namespace rpc {
Initialize(const std::string & url,const MemAllocateCallback & allocate_cb)24 bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) {
25 if (!ParseURL(url, &ip_addr_, &port_)) {
26 MS_LOG(EXCEPTION) << "Failed to parse url " << url;
27 }
28
29 return InitializeURPC(dev_name_, ip_addr_, port_);
30 }
31
Finalize()32 void RDMAServer::Finalize() {
33 if (message_handler_) {
34 urpc_unregister_handler_func(nullptr, func_id_);
35 }
36 if (kURPCInited) {
37 urpc_uninit_func();
38 kURPCInited = false;
39 }
40 }
41
SetMessageHandler(const MessageHandler & handler,uint32_t func_id)42 void RDMAServer::SetMessageHandler(const MessageHandler &handler, uint32_t func_id) {
43 if (!handler) {
44 MS_LOG(EXCEPTION) << "The handler of RDMAServer is empty.";
45 }
46 message_handler_ = handler;
47 func_id_ = func_id;
48
49 if (urpc_register_raw_handler_explicit_func(urpc_req_handler, this, urpc_rsp_handler, urpc_allocator_, func_id_) !=
50 kURPCSuccess) {
51 MS_LOG(EXCEPTION) << "Failed to set handler for RDMAServer of func_id: " << func_id_;
52 }
53 }
54
GetIP() const55 std::string RDMAServer::GetIP() const { return ip_addr_; }
56
GetPort() const57 uint32_t RDMAServer::GetPort() const { return static_cast<uint32_t>(port_); }
58
urpc_req_handler(struct urpc_sgl * req,void * arg,struct urpc_sgl * rsp)59 void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
60 MS_ERROR_IF_NULL_WO_RET_VAL(req);
61 MS_ERROR_IF_NULL_WO_RET_VAL(arg);
62 MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
63
64 MessageBase *msg = new (std::nothrow) MessageBase();
65 MS_ERROR_IF_NULL_WO_RET_VAL(msg);
66 // Pay attention: when client send one message with URPC_SGE_FLAG_RENDEZVOUS, the data is stored in sge[1].
67 msg->data = reinterpret_cast<void *>(req->sge[1].addr);
68 msg->size = req->sge[1].length;
69
70 RDMAServer *server = static_cast<RDMAServer *>(arg);
71 MessageHandler message_handler = server->message_handler_;
72 (void)message_handler(msg);
73
74 std::string rsp_msg = "Client calls " + std::to_string(server->func_id()) + " function.";
75 auto rsp_buf = server->urpc_allocator()->alloc(rsp_msg.size());
76 if (memcpy_s(rsp_buf, rsp_msg.size(), rsp_msg.c_str(), rsp_msg.size()) != EOK) {
77 server->urpc_allocator().free(rsp_buf);
78 MS_LOG(EXCEPTION) << "Failed to memcpy_s for response message.";
79 }
80 rsp->sge[0].addr = reinterpret_cast<uintptr_t>(rsp_buf);
81 rsp->sge[0].length = rsp_msg.size();
82 rsp->sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
83 rsp->sge_num = 1;
84 }
85
urpc_rsp_handler(struct urpc_sgl * rsp,void * arg)86 void RDMAServer::urpc_rsp_handler(struct urpc_sgl *rsp, void *arg) {
87 MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
88 MS_ERROR_IF_NULL_WO_RET_VAL(arg);
89
90 auto urpc_allocator = static_cast<struct urpc_buffer_allocator *>(arg);
91 urpc_allocator->free(reinterpret_cast<void *>(rsp->sge[0].addr));
92 }
93 } // namespace rpc
94 } // namespace distributed
95 } // namespace mindspore
96