• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/rpc/rdma/rdma_server.h"
18 
19 #include <string>
20 
21 namespace mindspore {
22 namespace distributed {
23 namespace rpc {
Initialize(const std::string & url,const MemAllocateCallback & allocate_cb)24 bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) {
25   if (!ParseURL(url, &ip_addr_, &port_)) {
26     MS_LOG(EXCEPTION) << "Failed to parse url " << url;
27   }
28 
29   return InitializeURPC(dev_name_, ip_addr_, port_);
30 }
31 
Finalize()32 void RDMAServer::Finalize() {
33   if (message_handler_) {
34     urpc_unregister_handler_func(nullptr, func_id_);
35   }
36   if (kURPCInited) {
37     urpc_uninit_func();
38     kURPCInited = false;
39   }
40 }
41 
SetMessageHandler(const MessageHandler & handler,uint32_t func_id)42 void RDMAServer::SetMessageHandler(const MessageHandler &handler, uint32_t func_id) {
43   if (!handler) {
44     MS_LOG(EXCEPTION) << "The handler of RDMAServer is empty.";
45   }
46   message_handler_ = handler;
47   func_id_ = func_id;
48 
49   if (urpc_register_raw_handler_explicit_func(urpc_req_handler, this, urpc_rsp_handler, urpc_allocator_, func_id_) !=
50       kURPCSuccess) {
51     MS_LOG(EXCEPTION) << "Failed to set handler for RDMAServer of func_id: " << func_id_;
52   }
53 }
54 
GetIP() const55 std::string RDMAServer::GetIP() const { return ip_addr_; }
56 
GetPort() const57 uint32_t RDMAServer::GetPort() const { return static_cast<uint32_t>(port_); }
58 
urpc_req_handler(struct urpc_sgl * req,void * arg,struct urpc_sgl * rsp)59 void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
60   MS_ERROR_IF_NULL_WO_RET_VAL(req);
61   MS_ERROR_IF_NULL_WO_RET_VAL(arg);
62   MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
63 
64   MessageBase *msg = new (std::nothrow) MessageBase();
65   MS_ERROR_IF_NULL_WO_RET_VAL(msg);
66   // Pay attention: when client send one message with URPC_SGE_FLAG_RENDEZVOUS, the data is stored in sge[1].
67   msg->data = reinterpret_cast<void *>(req->sge[1].addr);
68   msg->size = req->sge[1].length;
69 
70   RDMAServer *server = static_cast<RDMAServer *>(arg);
71   MessageHandler message_handler = server->message_handler_;
72   (void)message_handler(msg);
73 
74   std::string rsp_msg = "Client calls " + std::to_string(server->func_id()) + " function.";
75   auto rsp_buf = server->urpc_allocator()->alloc(rsp_msg.size());
76   if (memcpy_s(rsp_buf, rsp_msg.size(), rsp_msg.c_str(), rsp_msg.size()) != EOK) {
77     server->urpc_allocator().free(rsp_buf);
78     MS_LOG(EXCEPTION) << "Failed to memcpy_s for response message.";
79   }
80   rsp->sge[0].addr = reinterpret_cast<uintptr_t>(rsp_buf);
81   rsp->sge[0].length = rsp_msg.size();
82   rsp->sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
83   rsp->sge_num = 1;
84 }
85 
urpc_rsp_handler(struct urpc_sgl * rsp,void * arg)86 void RDMAServer::urpc_rsp_handler(struct urpc_sgl *rsp, void *arg) {
87   MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
88   MS_ERROR_IF_NULL_WO_RET_VAL(arg);
89 
90   auto urpc_allocator = static_cast<struct urpc_buffer_allocator *>(arg);
91   urpc_allocator->free(reinterpret_cast<void *>(rsp->sge[0].addr));
92 }
93 }  // namespace rpc
94 }  // namespace distributed
95 }  // namespace mindspore
96