• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <functional>
22 #include <condition_variable>
23 #include "proto/topology.pb.h"
24 #include "kernel/framework_utils.h"
25 #include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
26 #include "include/backend/optimizer/helper.h"
27 #include "include/backend/distributed/rpc/tcp/constants.h"
28 
29 namespace mindspore {
30 namespace runtime {
~RecvActor()31 RecvActor::~RecvActor() {
32   if (server_) {
33     try {
34       server_->Finalize();
35     } catch (const std::exception &) {
36       MS_LOG(ERROR) << "Failed to finalize for tcp server in recv actor.";
37     }
38     server_ = nullptr;
39   }
40 }
41 
SetOpcontext(OpContext<DeviceTensor> * const op_context)42 void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
43   std::unique_lock<std::mutex> lock(context_mtx_);
44   MS_EXCEPTION_IF_NULL(op_context);
45   op_context_ = op_context;
46 }
47 
ResetOpcontext()48 void RecvActor::ResetOpcontext() {
49   std::unique_lock<std::mutex> lock(context_mtx_);
50   is_context_valid_ = false;
51 }
52 
UpdateStatus()53 void RecvActor::UpdateStatus() {
54   std::unique_lock<std::mutex> lock(context_mtx_);
55   is_context_valid_ = true;
56   context_cv_.notify_all();
57 }
58 
SetRouteInfo(uint32_t,const std::string &,const std::string & recv_src_node_name,const std::string & recv_dst_node_name)59 void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
60                              const std::string &recv_dst_node_name) {
61   (void)rpc_input_node_name_.emplace_back(recv_src_node_name);
62   input_inter_process_num_++;
63 }
64 
StartServer()65 bool RecvActor::StartServer() {
66   // Step 1: Create a rpc server and start listening.
67 
68 #ifdef ENABLE_RDMA
69   if (common::GetEnv(kEnableRDMA) == "1") {
70     std::string ip = common::GetEnv(kRDMAIP);
71     uint32_t min_port = ClusterContext::instance()->port_range().first;
72     uint32_t max_port = ClusterContext::instance()->port_range().second;
73     uint32_t current_port = min_port;
74     std::string url = ip + ":" + std::to_string(current_port);
75 
76     uint32_t retry_num = 0;
77     server_ = std::make_unique<RDMAServer>();
78     MS_EXCEPTION_IF_NULL(server_);
79     while (!server_->Initialize(url) && retry_num++ < kMaxRetryPortNum && current_port <= max_port) {
80       ++current_port;
81       MS_LOG(WARNING) << "Failed to initialize RDMAServer with url: " << url
82                       << ". Port number maybe occupied. Retry with increased port number: " << current_port;
83       url = ip + ":" + std::to_string(current_port);
84     }
85     if (!kURPCInited) {
86       MS_LOG(EXCEPTION) << "Failed to initialize RDMAServer.";
87     }
88   } else {
89     server_ = std::make_unique<TCPServer>(false, distributed::cluster::ClusterContext::instance()->port_range());
90     MS_EXCEPTION_IF_NULL(server_);
91     // Set the memory allocating callback using void* message.
92     std::function<void *(size_t size)> allocate_callback =
93       std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
94     if (!server_->Initialize(allocate_callback)) {
95       MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
96     }
97   }
98 #else
99   server_ = std::make_unique<TCPServer>(false, distributed::cluster::ClusterContext::instance()->port_range());
100   MS_EXCEPTION_IF_NULL(server_);
101   // Set the memory allocating callback using void* message.
102   std::function<void *(size_t size)> allocate_callback =
103     std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
104   if (!server_->Initialize(allocate_callback)) {
105     MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
106   }
107 #endif
108 
109   // Step 2: Set the message handler of the server.
110   SetMessageHandler();
111 
112   ip_ = server_->GetIP();
113   port_ = server_->GetPort();
114   std::string server_url = ip_ + ":" + std::to_string(port_);
115   // Step 3: Register the server address to route table. The server should not be connected before this step is done.
116   for (const auto &inter_process_edge_name : inter_process_edge_names_) {
117     MS_LOG(INFO) << "Start server for recv actor. Server address: " << server_url
118                  << ", remote function id: " << kRemoteFuncId
119                  << ", inter-process edge name: " << inter_process_edge_name;
120     distributed::cluster::topology::ActorAddress recv_actor_addresss;
121     recv_actor_addresss.set_actor_id(inter_process_edge_name);
122     recv_actor_addresss.set_ip(ip_);
123     recv_actor_addresss.set_port(port_);
124     recv_actor_addresss.set_func_id(kRemoteFuncId);
125     MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
126     if (!actor_route_table_proxy_->RegisterRoute(inter_process_edge_name, recv_actor_addresss)) {
127       MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_name << " " << server_url
128                         << " when starting server.";
129     }
130   }
131   return true;
132 }
133 
Clear()134 void RecvActor::Clear() {
135   if (server_) {
136     server_->Finalize();
137     server_ = nullptr;
138   }
139 }
140 
StopRpcAtException()141 void RecvActor::StopRpcAtException() {
142   std::unique_lock<std::mutex> lock(context_mtx_);
143   if (!is_context_valid_) {
144     is_exception_thrown_ = true;
145     context_cv_.notify_all();
146   }
147 }
148 
RunOpInterProcessData(MessageBase * const msg,OpContext<DeviceTensor> * const context)149 void RecvActor::RunOpInterProcessData(MessageBase *const msg, OpContext<DeviceTensor> *const context) {
150   MS_ERROR_IF_NULL_WO_RET_VAL(msg);
151   MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
152   MS_ERROR_IF_NULL_WO_RET_VAL(context);
153   auto &sequential_num = context->sequential_num_;
154   (void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
155 
156   auto is_run = CheckRunningCondition(context);
157   MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
158                << inter_process_edge_names_ << ". Check running condition:" << is_run;
159 
160   // Parse the message from remote peer and set to rpc recv kernel.
161   auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
162   MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
163 
164   // We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
165   recv_kernel_mod->SetRemoteInput(msg);
166   if (common::GetEnv(kEnableRDMA) == "1") {
167     rdma_buf_ = msg->data;
168   }
169 
170   if (is_run) {
171     Run(context);
172   }
173   return;
174 }
175 
CheckRunningCondition(const OpContext<DeviceTensor> * context) const176 bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
177   MS_EXCEPTION_IF_NULL(context);
178   // Step 1: Judge data and control inputs are satisfied.
179   bool is_data_and_control_arrow_satisfied = AbstractActor::CheckRunningCondition(context);
180   if (!is_data_and_control_arrow_satisfied) {
181     return false;
182   }
183 
184   if (input_inter_process_num_ != 0) {
185     // Step 2: Judge inter-process inputs are satisfied.
186     const auto &inter_process_iter = input_op_inter_process_.find(context->sequential_num_);
187     if (inter_process_iter == input_op_inter_process_.end()) {
188       return false;
189     }
190 
191     const auto &current_inter_process_inputs = inter_process_iter->second;
192     if (current_inter_process_inputs.size() < input_inter_process_num_) {
193       return false;
194     } else if (current_inter_process_inputs.size() > input_inter_process_num_) {
195       MS_LOG(ERROR) << "Invalid inter process input num:" << current_inter_process_inputs.size()
196                     << " need:" << input_inter_process_num_ << " for actor:" << GetAID();
197       return false;
198     }
199   }
200   return true;
201 }
202 
EraseInput(const OpContext<DeviceTensor> * context)203 void RecvActor::EraseInput(const OpContext<DeviceTensor> *context) {
204   MS_EXCEPTION_IF_NULL(context);
205   KernelActor::EraseInput(context);
206 
207   if (input_op_inter_process_.count(context->sequential_num_) != 0) {
208     (void)input_op_inter_process_.erase(context->sequential_num_);
209   }
210   // Release data allocated by AllocateMessage.
211   if (recv_data_ != nullptr) {
212     if (!WaitRuntimePipelineFinish(context)) {
213       MS_LOG(INFO) << "Run failed and early stop.";
214       return;
215     }
216     MS_EXCEPTION_IF_CHECK_FAIL((!device_contexts_.empty()), "The device context doesn't exist.");
217     MS_EXCEPTION_IF_NULL(device_contexts_[0]);
218     MS_EXCEPTION_IF_NULL(device_contexts_[0]->device_res_manager_);
219     device_contexts_[0]->device_res_manager_->FreeMemory(recv_data_.get());
220   }
221 
222 #ifdef ENABLE_RDMA
223   // Release data of URPC by caller.
224   if (common::GetEnv(kEnableRDMA) == "1" && rdma_buf_ != nullptr) {
225     if (!WaitRuntimePipelineFinish(context)) {
226       MS_LOG(INFO) << "Run failed and early stop.";
227       return;
228     }
229     auto rdma_server = dynamic_cast<RDMAServer *>(server_.get());
230     MS_EXCEPTION_IF_NULL(rdma_server);
231     auto urpc_alloc = rdma_server->urpc_allocator();
232     MS_EXCEPTION_IF_NULL(urpc_alloc);
233     urpc_alloc->free(rdma_buf_);
234   }
235 #endif
236 }
237 
Run(OpContext<DeviceTensor> * const context)238 void RecvActor::Run(OpContext<DeviceTensor> *const context) {
239   MS_EXCEPTION_IF_NULL(context);
240   MS_EXCEPTION_IF_NULL(kernel_info_);
241   auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
242   MS_EXCEPTION_IF_NULL(recv_kernel_mod);
243   auto remote_input = recv_kernel_mod->GetRemoteInput();
244   bool need_finalize = false;
245   // Preprocess the remote input in case data is dynamic shape.
246   PreprocessRemoteInput(remote_input, &need_finalize);
247   if (need_finalize) {
248     return;
249   }
250   KernelActor::Run(context);
251 }
252 
AllocateMessage(size_t size)253 void *RecvActor::AllocateMessage(size_t size) {
254   // Block this method until the context is valid.
255   std::unique_lock<std::mutex> lock(context_mtx_);
256   context_cv_.wait(lock, [this] { return is_context_valid_; });
257   lock.unlock();
258 
259   return AllocateMemByDeviceRes(size);
260 }
261 
AllocateMemByDeviceRes(size_t size)262 void *RecvActor::AllocateMemByDeviceRes(size_t size) {
263   // Only need to create recv_data_ once.
264   // The real data is allocated and freed multiple times as recv_data_->ptr_.
265   if (recv_data_ == nullptr) {
266     recv_data_ = std::make_shared<CPUDeviceAddress>(nullptr, size);
267     MS_ERROR_IF_NULL_W_RET_VAL(recv_data_, nullptr);
268   } else {
269     recv_data_->SetSize(size);
270   }
271 
272   MS_EXCEPTION_IF_CHECK_FAIL((!device_contexts_.empty()), "The device context doesn't exist.");
273   MS_ERROR_IF_NULL_W_RET_VAL(device_contexts_[kIndex0], nullptr);
274   MS_ERROR_IF_NULL_W_RET_VAL(device_contexts_[kIndex0]->device_res_manager_, nullptr);
275   if (!device_contexts_[kIndex0]->device_res_manager_->AllocateMemory(recv_data_.get())) {
276     MS_LOG(ERROR) << "Failed to allocate memory size " << size;
277     return nullptr;
278   }
279   return recv_data_->GetMutablePtr();
280 }
281 
AddArgSpecForInput(AbstractBasePtrList * args_spec_list,const ShapeVector & shapes,TypeId data_type,size_t input_index) const282 void RecvActor::AddArgSpecForInput(AbstractBasePtrList *args_spec_list, const ShapeVector &shapes, TypeId data_type,
283                                    size_t input_index) const {
284   MS_EXCEPTION_IF_NULL(args_spec_list);
285   MS_EXCEPTION_IF_NULL(kernel_);
286   auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel_, input_index, false);
287   auto real_input = input_node_with_index.first;
288   size_t real_input_index = input_node_with_index.second;
289   MS_EXCEPTION_IF_NULL(real_input);
290   auto output_addr = AnfAlgo::GetMutableOutputAddr(real_input, real_input_index, false);
291   MS_EXCEPTION_IF_NULL(output_addr);
292   if (output_addr->GetNodeIndex().first == nullptr) {
293     output_addr->SetNodeIndex(kernel_, input_index);
294   }
295   auto out_tensor = std::make_shared<tensor::Tensor>(data_type, shapes);
296   MS_EXCEPTION_IF_NULL(out_tensor);
297   out_tensor->set_device_address(output_addr, false);
298 
299   auto real_abs = real_input->abstract();
300   MS_EXCEPTION_IF_NULL(real_abs);
301   auto updated_shape = std::make_shared<abstract::Shape>(shapes);
302   MS_EXCEPTION_IF_NULL(updated_shape);
303   if (real_abs->isa<abstract::AbstractTensor>()) {
304     real_abs->set_value(out_tensor);
305     real_abs->set_shape(updated_shape);
306   } else if (real_abs->isa<abstract::AbstractTuple>()) {
307     if (common::AnfAlgo::IsDynamicSequence(real_input)) {
308       MS_LOG_WITH_NODE(EXCEPTION, real_input)
309         << "Invalid dynamic sequence for actor:" << GetAID() << " node:" << real_input->DebugString();
310     }
311     auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
312     MS_EXCEPTION_IF_NULL(abstract_tuple);
313     MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abstract_tuple->elements().size()), "Index is out of range.");
314     auto tuple_elements = abstract_tuple->elements()[real_input_index];
315     MS_EXCEPTION_IF_NULL(tuple_elements);
316     tuple_elements->set_value(out_tensor);
317     tuple_elements->set_shape(updated_shape);
318   }
319   common::AnfAlgo::AddArgList(args_spec_list, real_input, real_input_index);
320 
321   // The inputs of RpcRecv node are all in device tensor store(weight or value node), framework does not free these
322   // device tensors. If these device tensors are not released, they will persist in the same memory size. In dynamic
323   // shape scenarios, there will be out of memory bounds problems.
324   MS_EXCEPTION_IF_NULL(output_addr);
325   auto output_addr_size = AnfAlgo::GetOutputTensorMemSize(real_input, real_input_index);
326   if (output_addr_size != output_addr->GetSize()) {
327     output_addr->SetSize(output_addr_size);
328     MS_EXCEPTION_IF_NULL(device_contexts_[0]);
329     MS_EXCEPTION_IF_NULL(device_contexts_[0]->device_res_manager_);
330     device_contexts_[0]->device_res_manager_->FreeMemory(output_addr.get());
331   }
332 
333   // Update kernel tensor shape for dynamic shape case.
334   const auto &output_kernel_tensor = output_addr->kernel_tensor();
335   MS_EXCEPTION_IF_NULL(output_kernel_tensor);
336   const auto &new_shape = real_abs->GetShape();
337   MS_EXCEPTION_IF_NULL(new_shape);
338   output_kernel_tensor->SetShape(new_shape->Clone());
339 }
340 
ParseDynamicShapeData(const RpcDataPtr & dynamic_shape_data,size_t data_size,AbstractBasePtrList * args_spec_list,size_t count)341 size_t RecvActor::ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size,
342                                         AbstractBasePtrList *args_spec_list, size_t count) {
343   // The data which could be parsed by offset in dynamic shape scenario.
344   auto data_to_be_parsed = dynamic_shape_data;
345   // The real data offsets which will be used by RpcRecvKernel.
346   std::vector<size_t> real_data_offsets;
347 
348   // Once the magic header is dynamic shape, each input of the Recv is dynamic shape.
349   // So traverse each input and parse the dynamic shape data.
350   size_t offset = 0;
351   for (size_t i = 0; i < count; i++) {
352     if (data_to_be_parsed >= dynamic_shape_data + data_size) {
353       MS_LOG(EXCEPTION) << "The dynamic shape data size is invalid.";
354     }
355     // Step 1: parse the magic header which indicates the dynamic shape.
356     std::string dynamic_shape_magic_header(data_to_be_parsed, strlen(kRpcDynamicShapeData));
357     if (dynamic_shape_magic_header != kRpcDynamicShapeData) {
358       MS_LOG(EXCEPTION) << "The dynamie shape data must have the magic header RPC_DYNAMIC_SHAPE_DATA. But got "
359                         << dynamic_shape_magic_header;
360     }
361 
362     // Step 2: parse the size of serialized protobuf message.
363     data_to_be_parsed += strlen(kRpcDynamicShapeData);
364     size_t pb_msg_size = 0;
365     MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(&pb_msg_size, sizeof(pb_msg_size), data_to_be_parsed, sizeof(size_t)) == EOK,
366                                "memcpy_s protobuf message size failed.");
367 
368     // Step 3: deserialize the protobuf message.
369     data_to_be_parsed += sizeof(pb_msg_size);
370     rpc::DynamicShapeMessage pb_msg;
371     (void)pb_msg.ParseFromArray(data_to_be_parsed, SizeToInt(pb_msg_size));
372 
373     // Step 4: parse the data shape and
374     ShapeVector shapes(pb_msg.shape_vector().begin(), pb_msg.shape_vector().end());
375     TypeId data_type = static_cast<TypeId>(pb_msg.type_id());
376     data_to_be_parsed += pb_msg_size;
377 
378     // Step 5: get the size of real data as recv's input.
379     int64_t real_data_size = 1;
380     if (!kernel::GetShapeSize(shapes, TypeIdToType(data_type), &real_data_size)) {
381       MS_LOG(EXCEPTION) << "Getting shape size for shape " << shapes << " failed.";
382     }
383     data_to_be_parsed += real_data_size;
384 
385     // Step 6: update the abstract.
386     AddArgSpecForInput(args_spec_list, shapes, data_type, i);
387 
388     offset += strlen(kRpcDynamicShapeData) + sizeof(pb_msg_size) + pb_msg_size;
389     real_data_offsets.push_back(offset);
390     offset += LongToSize(real_data_size);
391   }
392 
393   auto recv_kernel_mod = dynamic_cast<kernel::RpcRecvKernelMod *>(kernel_info_->MutableKernelMod());
394   MS_EXCEPTION_IF_NULL(recv_kernel_mod);
395   recv_kernel_mod->set_real_data_offset(real_data_offsets);
396   return offset;
397 }
398 
PreprocessRemoteInput(const MessageBase * const msg,bool * need_finalize)399 void RecvActor::PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize) {
400   MS_EXCEPTION_IF_NULL(msg);
401   MS_EXCEPTION_IF_NULL(need_finalize);
402 
403   // Parse the void * data.
404   size_t data_size = msg->size;
405   MS_EXCEPTION_IF_NULL(msg->data);
406   std::string msg_magic_header = std::string(static_cast<RpcDataPtr>(msg->data), strlen(kRpcDynamicShapeData));
407   RpcDataPtr dynamic_shape_data = static_cast<RpcDataPtr>(msg->data);
408 
409   if (data_size <= strlen(kRpcDynamicShapeData)) {
410     MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
411     return;
412   }
413   if (msg_magic_header != kRpcDynamicShapeData) {
414     MS_LOG(DEBUG) << "This is not a dynamic shape data. No need to preprocess.";
415     return;
416   }
417 
418   MS_LOG(INFO) << "Preprocess for dynamic shape data.";
419   AbstractBasePtrList args_spec_list;
420   size_t input_size = common::AnfAlgo::GetInputTensorNum(kernel_);
421   size_t dynamic_shape_data_msg_len = ParseDynamicShapeData(dynamic_shape_data, data_size, &args_spec_list, input_size);
422   ParseFinalizeReqData(dynamic_shape_data_msg_len, msg, need_finalize);
423 }
424 
HandleMessage(MessageBase * const msg)425 MessageBase *RecvActor::HandleMessage(MessageBase *const msg) {
426   // Block the message handler if the context is invalid.
427   std::unique_lock<std::mutex> lock(context_mtx_);
428   context_cv_.wait(lock, [this] { return is_context_valid_ || is_exception_thrown_; });
429   if (is_exception_thrown_) {
430     MS_LOG(WARNING) << "Recv actor stops waiting for op_context at exception.";
431     return distributed::rpc::NULL_MSG;
432   }
433   lock.unlock();
434   // Once recv actor is launched, lock the context so that the next step's recv will not be launched in advance.
435   ResetOpcontext();
436 
437   MS_LOG(INFO) << "Rpc actor recv message for inter-process edge: " << inter_process_edge_names_;
438 
439   if (msg == nullptr || op_context_ == nullptr) {
440     return distributed::rpc::NULL_MSG;
441   }
442   ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
443   return distributed::rpc::NULL_MSG;
444 }
445 
SetMessageHandler()446 void RecvActor::SetMessageHandler() {
447   server_->SetMessageHandler(std::bind(&RecvActor::HandleMessage, this, std::placeholders::_1), ++kRemoteFuncId);
448 }
449 }  // namespace runtime
450 }  // namespace mindspore
451