1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
18
19 #include <memory>
20 #include <utility>
21 #include <functional>
22 #include <condition_variable>
23 #include "proto/topology.pb.h"
24 #include "kernel/framework_utils.h"
25 #include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
26 #include "include/backend/optimizer/helper.h"
27 #include "include/backend/distributed/rpc/tcp/constants.h"
28
29 namespace mindspore {
30 namespace runtime {
~RecvActor()31 RecvActor::~RecvActor() {
32 if (server_) {
33 try {
34 server_->Finalize();
35 } catch (const std::exception &) {
36 MS_LOG(ERROR) << "Failed to finalize for tcp server in recv actor.";
37 }
38 server_ = nullptr;
39 }
40 }
41
SetOpcontext(OpContext<DeviceTensor> * const op_context)42 void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
43 std::unique_lock<std::mutex> lock(context_mtx_);
44 MS_EXCEPTION_IF_NULL(op_context);
45 op_context_ = op_context;
46 }
47
ResetOpcontext()48 void RecvActor::ResetOpcontext() {
49 std::unique_lock<std::mutex> lock(context_mtx_);
50 is_context_valid_ = false;
51 }
52
UpdateStatus()53 void RecvActor::UpdateStatus() {
54 std::unique_lock<std::mutex> lock(context_mtx_);
55 is_context_valid_ = true;
56 context_cv_.notify_all();
57 }
58
SetRouteInfo(uint32_t,const std::string &,const std::string & recv_src_node_name,const std::string & recv_dst_node_name)59 void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
60 const std::string &recv_dst_node_name) {
61 (void)rpc_input_node_name_.emplace_back(recv_src_node_name);
62 input_inter_process_num_++;
63 }
64
StartServer()65 bool RecvActor::StartServer() {
66 // Step 1: Create a rpc server and start listening.
67
68 #ifdef ENABLE_RDMA
69 if (common::GetEnv(kEnableRDMA) == "1") {
70 std::string ip = common::GetEnv(kRDMAIP);
71 uint32_t min_port = ClusterContext::instance()->port_range().first;
72 uint32_t max_port = ClusterContext::instance()->port_range().second;
73 uint32_t current_port = min_port;
74 std::string url = ip + ":" + std::to_string(current_port);
75
76 uint32_t retry_num = 0;
77 server_ = std::make_unique<RDMAServer>();
78 MS_EXCEPTION_IF_NULL(server_);
79 while (!server_->Initialize(url) && retry_num++ < kMaxRetryPortNum && current_port <= max_port) {
80 ++current_port;
81 MS_LOG(WARNING) << "Failed to initialize RDMAServer with url: " << url
82 << ". Port number maybe occupied. Retry with increased port number: " << current_port;
83 url = ip + ":" + std::to_string(current_port);
84 }
85 if (!kURPCInited) {
86 MS_LOG(EXCEPTION) << "Failed to initialize RDMAServer.";
87 }
88 } else {
89 server_ = std::make_unique<TCPServer>(false, distributed::cluster::ClusterContext::instance()->port_range());
90 MS_EXCEPTION_IF_NULL(server_);
91 // Set the memory allocating callback using void* message.
92 std::function<void *(size_t size)> allocate_callback =
93 std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
94 if (!server_->Initialize(allocate_callback)) {
95 MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
96 }
97 }
98 #else
99 server_ = std::make_unique<TCPServer>(false, distributed::cluster::ClusterContext::instance()->port_range());
100 MS_EXCEPTION_IF_NULL(server_);
101 // Set the memory allocating callback using void* message.
102 std::function<void *(size_t size)> allocate_callback =
103 std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
104 if (!server_->Initialize(allocate_callback)) {
105 MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
106 }
107 #endif
108
109 // Step 2: Set the message handler of the server.
110 SetMessageHandler();
111
112 ip_ = server_->GetIP();
113 port_ = server_->GetPort();
114 std::string server_url = ip_ + ":" + std::to_string(port_);
115 // Step 3: Register the server address to route table. The server should not be connected before this step is done.
116 for (const auto &inter_process_edge_name : inter_process_edge_names_) {
117 MS_LOG(INFO) << "Start server for recv actor. Server address: " << server_url
118 << ", remote function id: " << kRemoteFuncId
119 << ", inter-process edge name: " << inter_process_edge_name;
120 distributed::cluster::topology::ActorAddress recv_actor_addresss;
121 recv_actor_addresss.set_actor_id(inter_process_edge_name);
122 recv_actor_addresss.set_ip(ip_);
123 recv_actor_addresss.set_port(port_);
124 recv_actor_addresss.set_func_id(kRemoteFuncId);
125 MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
126 if (!actor_route_table_proxy_->RegisterRoute(inter_process_edge_name, recv_actor_addresss)) {
127 MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_name << " " << server_url
128 << " when starting server.";
129 }
130 }
131 return true;
132 }
133
Clear()134 void RecvActor::Clear() {
135 if (server_) {
136 server_->Finalize();
137 server_ = nullptr;
138 }
139 }
140
StopRpcAtException()141 void RecvActor::StopRpcAtException() {
142 std::unique_lock<std::mutex> lock(context_mtx_);
143 if (!is_context_valid_) {
144 is_exception_thrown_ = true;
145 context_cv_.notify_all();
146 }
147 }
148
RunOpInterProcessData(MessageBase * const msg,OpContext<DeviceTensor> * const context)149 void RecvActor::RunOpInterProcessData(MessageBase *const msg, OpContext<DeviceTensor> *const context) {
150 MS_ERROR_IF_NULL_WO_RET_VAL(msg);
151 MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
152 MS_ERROR_IF_NULL_WO_RET_VAL(context);
153 auto &sequential_num = context->sequential_num_;
154 (void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
155
156 auto is_run = CheckRunningCondition(context);
157 MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
158 << inter_process_edge_names_ << ". Check running condition:" << is_run;
159
160 // Parse the message from remote peer and set to rpc recv kernel.
161 auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
162 MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
163
164 // We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
165 recv_kernel_mod->SetRemoteInput(msg);
166 if (common::GetEnv(kEnableRDMA) == "1") {
167 rdma_buf_ = msg->data;
168 }
169
170 if (is_run) {
171 Run(context);
172 }
173 return;
174 }
175
CheckRunningCondition(const OpContext<DeviceTensor> * context) const176 bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
177 MS_EXCEPTION_IF_NULL(context);
178 // Step 1: Judge data and control inputs are satisfied.
179 bool is_data_and_control_arrow_satisfied = AbstractActor::CheckRunningCondition(context);
180 if (!is_data_and_control_arrow_satisfied) {
181 return false;
182 }
183
184 if (input_inter_process_num_ != 0) {
185 // Step 2: Judge inter-process inputs are satisfied.
186 const auto &inter_process_iter = input_op_inter_process_.find(context->sequential_num_);
187 if (inter_process_iter == input_op_inter_process_.end()) {
188 return false;
189 }
190
191 const auto ¤t_inter_process_inputs = inter_process_iter->second;
192 if (current_inter_process_inputs.size() < input_inter_process_num_) {
193 return false;
194 } else if (current_inter_process_inputs.size() > input_inter_process_num_) {
195 MS_LOG(ERROR) << "Invalid inter process input num:" << current_inter_process_inputs.size()
196 << " need:" << input_inter_process_num_ << " for actor:" << GetAID();
197 return false;
198 }
199 }
200 return true;
201 }
202
EraseInput(const OpContext<DeviceTensor> * context)203 void RecvActor::EraseInput(const OpContext<DeviceTensor> *context) {
204 MS_EXCEPTION_IF_NULL(context);
205 KernelActor::EraseInput(context);
206
207 if (input_op_inter_process_.count(context->sequential_num_) != 0) {
208 (void)input_op_inter_process_.erase(context->sequential_num_);
209 }
210 // Release data allocated by AllocateMessage.
211 if (recv_data_ != nullptr) {
212 if (!WaitRuntimePipelineFinish(context)) {
213 MS_LOG(INFO) << "Run failed and early stop.";
214 return;
215 }
216 MS_EXCEPTION_IF_CHECK_FAIL((!device_contexts_.empty()), "The device context doesn't exist.");
217 MS_EXCEPTION_IF_NULL(device_contexts_[0]);
218 MS_EXCEPTION_IF_NULL(device_contexts_[0]->device_res_manager_);
219 device_contexts_[0]->device_res_manager_->FreeMemory(recv_data_.get());
220 }
221
222 #ifdef ENABLE_RDMA
223 // Release data of URPC by caller.
224 if (common::GetEnv(kEnableRDMA) == "1" && rdma_buf_ != nullptr) {
225 if (!WaitRuntimePipelineFinish(context)) {
226 MS_LOG(INFO) << "Run failed and early stop.";
227 return;
228 }
229 auto rdma_server = dynamic_cast<RDMAServer *>(server_.get());
230 MS_EXCEPTION_IF_NULL(rdma_server);
231 auto urpc_alloc = rdma_server->urpc_allocator();
232 MS_EXCEPTION_IF_NULL(urpc_alloc);
233 urpc_alloc->free(rdma_buf_);
234 }
235 #endif
236 }
237
Run(OpContext<DeviceTensor> * const context)238 void RecvActor::Run(OpContext<DeviceTensor> *const context) {
239 MS_EXCEPTION_IF_NULL(context);
240 MS_EXCEPTION_IF_NULL(kernel_info_);
241 auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
242 MS_EXCEPTION_IF_NULL(recv_kernel_mod);
243 auto remote_input = recv_kernel_mod->GetRemoteInput();
244 bool need_finalize = false;
245 // Preprocess the remote input in case data is dynamic shape.
246 PreprocessRemoteInput(remote_input, &need_finalize);
247 if (need_finalize) {
248 return;
249 }
250 KernelActor::Run(context);
251 }
252
AllocateMessage(size_t size)253 void *RecvActor::AllocateMessage(size_t size) {
254 // Block this method until the context is valid.
255 std::unique_lock<std::mutex> lock(context_mtx_);
256 context_cv_.wait(lock, [this] { return is_context_valid_; });
257 lock.unlock();
258
259 return AllocateMemByDeviceRes(size);
260 }
261
AllocateMemByDeviceRes(size_t size)262 void *RecvActor::AllocateMemByDeviceRes(size_t size) {
263 // Only need to create recv_data_ once.
264 // The real data is allocated and freed multiple times as recv_data_->ptr_.
265 if (recv_data_ == nullptr) {
266 recv_data_ = std::make_shared<CPUDeviceAddress>(nullptr, size);
267 MS_ERROR_IF_NULL_W_RET_VAL(recv_data_, nullptr);
268 } else {
269 recv_data_->SetSize(size);
270 }
271
272 MS_EXCEPTION_IF_CHECK_FAIL((!device_contexts_.empty()), "The device context doesn't exist.");
273 MS_ERROR_IF_NULL_W_RET_VAL(device_contexts_[kIndex0], nullptr);
274 MS_ERROR_IF_NULL_W_RET_VAL(device_contexts_[kIndex0]->device_res_manager_, nullptr);
275 if (!device_contexts_[kIndex0]->device_res_manager_->AllocateMemory(recv_data_.get())) {
276 MS_LOG(ERROR) << "Failed to allocate memory size " << size;
277 return nullptr;
278 }
279 return recv_data_->GetMutablePtr();
280 }
281
AddArgSpecForInput(AbstractBasePtrList * args_spec_list,const ShapeVector & shapes,TypeId data_type,size_t input_index) const282 void RecvActor::AddArgSpecForInput(AbstractBasePtrList *args_spec_list, const ShapeVector &shapes, TypeId data_type,
283 size_t input_index) const {
284 MS_EXCEPTION_IF_NULL(args_spec_list);
285 MS_EXCEPTION_IF_NULL(kernel_);
286 auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel_, input_index, false);
287 auto real_input = input_node_with_index.first;
288 size_t real_input_index = input_node_with_index.second;
289 MS_EXCEPTION_IF_NULL(real_input);
290 auto output_addr = AnfAlgo::GetMutableOutputAddr(real_input, real_input_index, false);
291 MS_EXCEPTION_IF_NULL(output_addr);
292 if (output_addr->GetNodeIndex().first == nullptr) {
293 output_addr->SetNodeIndex(kernel_, input_index);
294 }
295 auto out_tensor = std::make_shared<tensor::Tensor>(data_type, shapes);
296 MS_EXCEPTION_IF_NULL(out_tensor);
297 out_tensor->set_device_address(output_addr, false);
298
299 auto real_abs = real_input->abstract();
300 MS_EXCEPTION_IF_NULL(real_abs);
301 auto updated_shape = std::make_shared<abstract::Shape>(shapes);
302 MS_EXCEPTION_IF_NULL(updated_shape);
303 if (real_abs->isa<abstract::AbstractTensor>()) {
304 real_abs->set_value(out_tensor);
305 real_abs->set_shape(updated_shape);
306 } else if (real_abs->isa<abstract::AbstractTuple>()) {
307 if (common::AnfAlgo::IsDynamicSequence(real_input)) {
308 MS_LOG_WITH_NODE(EXCEPTION, real_input)
309 << "Invalid dynamic sequence for actor:" << GetAID() << " node:" << real_input->DebugString();
310 }
311 auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
312 MS_EXCEPTION_IF_NULL(abstract_tuple);
313 MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abstract_tuple->elements().size()), "Index is out of range.");
314 auto tuple_elements = abstract_tuple->elements()[real_input_index];
315 MS_EXCEPTION_IF_NULL(tuple_elements);
316 tuple_elements->set_value(out_tensor);
317 tuple_elements->set_shape(updated_shape);
318 }
319 common::AnfAlgo::AddArgList(args_spec_list, real_input, real_input_index);
320
321 // The inputs of RpcRecv node are all in device tensor store(weight or value node), framework does not free these
322 // device tensors. If these device tensors are not released, they will persist in the same memory size. In dynamic
323 // shape scenarios, there will be out of memory bounds problems.
324 MS_EXCEPTION_IF_NULL(output_addr);
325 auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(real_input, real_input_index);
326 if (output_addr_size != output_addr->GetSize()) {
327 output_addr->SetSize(output_addr_size);
328 MS_EXCEPTION_IF_NULL(device_contexts_[0]);
329 MS_EXCEPTION_IF_NULL(device_contexts_[0]->device_res_manager_);
330 device_contexts_[0]->device_res_manager_->FreeMemory(output_addr.get());
331 }
332
333 // Update kernel tensor shape for dynamic shape case.
334 const auto &output_kernel_tensor = output_addr->kernel_tensor();
335 MS_EXCEPTION_IF_NULL(output_kernel_tensor);
336 const auto &new_shape = real_abs->GetShape();
337 MS_EXCEPTION_IF_NULL(new_shape);
338 output_kernel_tensor->SetShape(new_shape->Clone());
339 }
340
ParseDynamicShapeData(const RpcDataPtr & dynamic_shape_data,size_t data_size,AbstractBasePtrList * args_spec_list,size_t count)341 size_t RecvActor::ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size,
342 AbstractBasePtrList *args_spec_list, size_t count) {
343 // The data which could be parsed by offset in dynamic shape scenario.
344 auto data_to_be_parsed = dynamic_shape_data;
345 // The real data offsets which will be used by RpcRecvKernel.
346 std::vector<size_t> real_data_offsets;
347
348 // Once the magic header is dynamic shape, each input of the Recv is dynamic shape.
349 // So traverse each input and parse the dynamic shape data.
350 size_t offset = 0;
351 for (size_t i = 0; i < count; i++) {
352 if (data_to_be_parsed >= dynamic_shape_data + data_size) {
353 MS_LOG(EXCEPTION) << "The dynamic shape data size is invalid.";
354 }
355 // Step 1: parse the magic header which indicates the dynamic shape.
356 std::string dynamic_shape_magic_header(data_to_be_parsed, strlen(kRpcDynamicShapeData));
357 if (dynamic_shape_magic_header != kRpcDynamicShapeData) {
358 MS_LOG(EXCEPTION) << "The dynamie shape data must have the magic header RPC_DYNAMIC_SHAPE_DATA. But got "
359 << dynamic_shape_magic_header;
360 }
361
362 // Step 2: parse the size of serialized protobuf message.
363 data_to_be_parsed += strlen(kRpcDynamicShapeData);
364 size_t pb_msg_size = 0;
365 MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(&pb_msg_size, sizeof(pb_msg_size), data_to_be_parsed, sizeof(size_t)) == EOK,
366 "memcpy_s protobuf message size failed.");
367
368 // Step 3: deserialize the protobuf message.
369 data_to_be_parsed += sizeof(pb_msg_size);
370 rpc::DynamicShapeMessage pb_msg;
371 (void)pb_msg.ParseFromArray(data_to_be_parsed, SizeToInt(pb_msg_size));
372
373 // Step 4: parse the data shape and
374 ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
375 TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
376 data_to_be_parsed += pb_msg_size;
377
378 // Step 5: get the size of real data as recv's input.
379 int64_t real_data_size = 1;
380 if (!kernel::GetShapeSize(shapes, TypeIdToType(data_type), &real_data_size)) {
381 MS_LOG(EXCEPTION) << "Getting shape size for shape " << shapes << " failed.";
382 }
383 data_to_be_parsed += real_data_size;
384
385 // Step 6: update the abstract.
386 AddArgSpecForInput(args_spec_list, shapes, data_type, i);
387
388 offset += strlen(kRpcDynamicShapeData) + sizeof(pb_msg_size) + pb_msg_size;
389 real_data_offsets.push_back(offset);
390 offset += LongToSize(real_data_size);
391 }
392
393 auto recv_kernel_mod = dynamic_cast<kernel::RpcRecvKernelMod *>(kernel_info_->MutableKernelMod());
394 MS_EXCEPTION_IF_NULL(recv_kernel_mod);
395 recv_kernel_mod->set_real_data_offset(real_data_offsets);
396 return offset;
397 }
398
PreprocessRemoteInput(const MessageBase * const msg,bool * need_finalize)399 void RecvActor::PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize) {
400 MS_EXCEPTION_IF_NULL(msg);
401 MS_EXCEPTION_IF_NULL(need_finalize);
402
403 // Parse the void * data.
404 size_t data_size = msg->size;
405 MS_EXCEPTION_IF_NULL(msg->data);
406 std::string msg_magic_header = std::string(static_cast<RpcDataPtr>(msg->data), strlen(kRpcDynamicShapeData));
407 RpcDataPtr dynamic_shape_data = static_cast<RpcDataPtr>(msg->data);
408
409 if (data_size <= strlen(kRpcDynamicShapeData)) {
410 MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
411 return;
412 }
413 if (msg_magic_header != kRpcDynamicShapeData) {
414 MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
415 return;
416 }
417
418 MS_LOG(INFO) << "Preprocess for dynamic shape data.";
419 AbstractBasePtrList args_spec_list;
420 size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_);
421 size_t dynamic_shape_data_msg_len = ParseDynamicShapeData(dynamic_shape_data, data_size, &args_spec_list, input_size);
422 ParseFinalizeReqData(dynamic_shape_data_msg_len, msg, need_finalize);
423 }
424
HandleMessage(MessageBase * const msg)425 MessageBase *RecvActor::HandleMessage(MessageBase *const msg) {
426 // Block the message handler if the context is invalid.
427 std::unique_lock<std::mutex> lock(context_mtx_);
428 context_cv_.wait(lock, [this] { return is_context_valid_ || is_exception_thrown_; });
429 if (is_exception_thrown_) {
430 MS_LOG(WARNING) << "Recv actor stops waiting for op_context at exception.";
431 return distributed::rpc::NULL_MSG;
432 }
433 lock.unlock();
434 // Once recv actor is launched, lock the context so that the next step's recv will not be launched in advance.
435 ResetOpcontext();
436
437 MS_LOG(INFO) << "Rpc actor recv message for inter-process edge: " << inter_process_edge_names_;
438
439 if (msg == nullptr || op_context_ == nullptr) {
440 return distributed::rpc::NULL_MSG;
441 }
442 ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
443 return distributed::rpc::NULL_MSG;
444 }
445
SetMessageHandler()446 void RecvActor::SetMessageHandler() {
447 server_->SetMessageHandler(std::bind(&RecvActor::HandleMessage, this, std::placeholders::_1), ++kRemoteFuncId);
448 }
449 } // namespace runtime
450 } // namespace mindspore
451