• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/actor/rpc/send_actor.h"
18 
19 #include <utility>
20 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
21 
22 namespace mindspore {
23 namespace runtime {
~SendActor()24 SendActor::~SendActor() {
25   if (client_) {
26     try {
27       (void)client_->Disconnect(server_url_);
28       client_->Finalize();
29     } catch (const std::exception &) {
30       MS_LOG(ERROR) << "Failed to disconnect and finalize for rpc client in send actor.";
31     }
32     client_ = nullptr;
33   }
34 }
35 
SetRouteInfo(uint32_t,const std::string &,const std::string & send_src_node_name,const std::string & send_dst_node_name)36 void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
37                              const std::string &send_dst_node_name) {
38   peer_actor_ids_ = inter_process_edge_names_;
39   (void)rpc_output_node_name_.emplace_back(send_dst_node_name);
40 }
41 
ConnectServer()42 bool SendActor::ConnectServer() {
43 #ifdef ENABLE_RDMA
44   if (common::GetEnv(kEnableRDMA) == "1") {
45     client_ = std::make_unique<RDMAClient>();
46   } else {
47     client_ = std::make_unique<TCPClient>();
48   }
49 #else
50   client_ = std::make_unique<TCPClient>();
51 #endif
52   MS_EXCEPTION_IF_NULL(client_);
53 
54   if (!client_->Initialize()) {
55     MS_LOG(EXCEPTION) << "Failed to initialize rpc server for send actor.";
56   }
57   // Lookup actor addresses for each peer actor.
58   for (const auto &peer_actor_id : peer_actor_ids_) {
59     MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
60     auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
61 
62     // If route is successfully looked up, peer_actor_address is not empty.
63     server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
64     remote_func_id_ = peer_actor_address.func_id();
65     auto free_callback = std::bind(&SendActor::FreeMessage, this, std::placeholders::_1);
66     size_t retry_count = 60;
67     if (!client_->Connect(server_url_, retry_count, free_callback)) {
68       MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
69     }
70 
71     MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", remote function id: " << remote_func_id_
72                  << ", inter-process edge name: " << peer_actor_id;
73     peer_actor_urls_[peer_actor_id] = server_url_;
74   }
75 
76   return true;
77 }
78 
FlushData()79 void SendActor::FlushData() {
80   MS_EXCEPTION_IF_NULL(client_);
81   if (!client_->Flush(server_url_)) {
82     MS_LOG(EXCEPTION) << "Failed to flush client for server " << server_url_;
83   }
84 }
85 
Clear()86 void SendActor::Clear() {
87   if (client_) {
88     (void)client_->Disconnect(server_url_);
89     client_->Finalize();
90     client_ = nullptr;
91   }
92 }
93 
LaunchKernel(OpContext<DeviceTensor> * const context,bool is_skip_launch)94 bool SendActor::LaunchKernel(OpContext<DeviceTensor> *const context, bool is_skip_launch) {
95   if (is_skip_launch) {
96     return KernelActor::LaunchKernel(context, is_skip_launch);
97   }
98   MS_ERROR_IF_NULL_W_RET_VAL(context, false);
99   // Set context for later usage in FreeMessage.
100   context_ = context;
101 
102   if (!KernelActor::LaunchKernel(context, is_skip_launch)) {
103     MS_LOG(ERROR) << "Launching kernel for send actor failed.";
104     return false;
105   }
106 
107   // Send input data(inter-process data is the input of the Send kernel) to peers.
108   if (input_device_tensors_.empty()) {
109     MS_LOG(ERROR) << "Send kernel has no output tensor.";
110     return false;
111   }
112   for (const auto &peer : peer_actor_urls_) {
113     std::string peer_server_url = peer.second;
114     auto message = BuildRpcMessage(peer_server_url);
115     MS_ERROR_IF_NULL_W_RET_VAL(message, false);
116     MS_ERROR_IF_NULL_W_RET_VAL(client_, false);
117     MS_LOG(INFO) << "Rpc actor send message for inter-process edge: " << peer.first;
118     client_->SendAsync(std::move(message));
119   }
120   return true;
121 }
122 
EraseInput(const OpContext<DeviceTensor> * context)123 void SendActor::EraseInput(const OpContext<DeviceTensor> *context) {
124   MS_EXCEPTION_IF_NULL(context);
125   AbstractActor::EraseInput(context);
126 
127   if (input_op_inter_process_.count(context->sequential_num_) != 0) {
128     (void)input_op_inter_process_.erase(context->sequential_num_);
129   }
130 }
131 
BuildRpcMessage(const std::string & server_url)132 std::unique_ptr<MessageBase> SendActor::BuildRpcMessage(const std::string &server_url) {
133   std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
134   MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
135   message->to = AID("", server_url);
136   message->func_id_ = remote_func_id_;
137 
138   // To reach optimal performance, we use workspace memory as the data sent to the remote. So the size must be
139   // strictly checked to avoid illegal memory access.
140   auto send_workspace = workspace_device_tensors_;
141   if (send_workspace.empty()) {
142     MS_LOG(EXCEPTION) << "RpcSendKernel's workspace should not be empty.";
143   }
144   // Only use one piece of workspace memory to avoid extra memory copying and serialize inputs data to one message.
145   auto workspace_addr = send_workspace[kIndex0];
146   if (is_dynamic_shape_) {
147     MS_LOG(INFO) << "This send actor builds message with dynamic shape.";
148     SerializeDynamicShapeMessage(message.get(), workspace_addr);
149   } else {
150     SerializeCommonMessage(message.get(), workspace_addr);
151   }
152 
153   MS_LOG(DEBUG) << "RpcSend message size is " << message->size;
154   return message;
155 }
156 
FreeMessage(void * data)157 bool SendActor::FreeMessage(void *data) {
158   auto memory_free_list = FindDeviceTensorNeedsFree(data);
159   ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list,
160                             device_contexts_[0], context_, GetAID());
161   return true;
162 }
163 
Flush()164 void SendActor::Flush() {
165   MS_EXCEPTION_IF_NULL(client_);
166   for (const auto &url : peer_actor_urls_) {
167     MS_LOG(DEBUG) << "Flush for url " << url.second;
168     if (!client_->Flush(url.second)) {
169       MS_LOG(EXCEPTION) << "Failed to flush for url " << url.second;
170     }
171   }
172 }
173 
FindDeviceTensorNeedsFree(const void * data) const174 std::vector<DeviceTensor *> SendActor::FindDeviceTensorNeedsFree(const void *data) const {
175   std::vector<DeviceTensor *> free_list;
176   // The sent data uses the memory of workspace. So query the DeviceTensor from workspace_device_tensors_.
177   for (const auto &device_tensor : workspace_device_tensors_) {
178     MS_ERROR_IF_NULL_W_RET_VAL(device_tensor, {});
179     if (data == device_tensor->GetMutablePtr()) {
180       free_list.push_back(device_tensor);
181     }
182   }
183   return free_list;
184 }
185 
SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data,const ShapeVector & shape_vec,const TypeId & data_type,const DeviceTensor * addr) const186 size_t SendActor::SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data, const ShapeVector &shape_vec,
187                                                    const TypeId &data_type, const DeviceTensor *addr) const {
188   MS_EXCEPTION_IF_NULL(rpc_data);
189   MS_EXCEPTION_IF_NULL(addr);
190 
191   // The serialize data size needs to be computed.
192   size_t serialized_data_size = 0;
193 
194   // Serialize data's meta info to protobuffer.
195   rpc::DynamicShapeMessage pb_msg;
196   pb_msg.set_type_id(static_cast<int>(data_type));
197   *pb_msg.mutable_shape_vector() = {shape_vec.begin(), shape_vec.end()};
198   std::string pb_msg_str = pb_msg.SerializeAsString();
199 
200   // Part 1. Magic header for dynamic shape.
201   size_t header_size = strlen(kRpcDynamicShapeData);
202   if (!CopyRpcDataWithOffset(&rpc_data, kRpcDynamicShapeData, header_size)) {
203     MS_LOG(EXCEPTION) << "Failed to copy data for kRpcDynamicShapeData.";
204   }
205   serialized_data_size += header_size;
206 
207   // Part 2. The size of the protobuf message DynamicShapeMessage.
208   size_t pb_msg_size = pb_msg_str.size();
209   if (!CopyRpcDataWithOffset(&rpc_data, &pb_msg_size, sizeof(pb_msg_size))) {
210     MS_LOG(EXCEPTION) << "Failed to copy data for protobuffer data's size.";
211   }
212   serialized_data_size += sizeof(pb_msg_size);
213 
214   // Part 3. Protobuf message DynamicShapeMessage.
215   if (!CopyRpcDataWithOffset(&rpc_data, pb_msg_str.c_str(), pb_msg_str.size())) {
216     MS_LOG(EXCEPTION) << "Failed to copy data for protobuffer data.";
217   }
218   serialized_data_size += pb_msg_str.size();
219 
220   // Part 4. The real data buffer of the input.
221   if (!CopyRpcDataWithOffset(&rpc_data, addr->GetMutablePtr(), addr->GetSize())) {
222     MS_LOG(EXCEPTION) << "Failed to copy data for real input data.";
223   }
224   serialized_data_size += addr->GetSize();
225 
226   return serialized_data_size;
227 }
228 
SerializeDynamicShapeMessage(MessageBase * message,const DeviceTensor * workspace_addr) const229 void SendActor::SerializeDynamicShapeMessage(MessageBase *message, const DeviceTensor *workspace_addr) const {
230   MS_EXCEPTION_IF_NULL(workspace_addr);
231   size_t offset = 0;
232   RpcDataPtr rpc_data = static_cast<RpcDataPtr>(workspace_addr->GetMutablePtr());
233   for (size_t i = 0; i < input_kernel_tensors_.size(); i++) {
234     auto shapes = input_kernel_tensors_[i]->GetShapeVector();
235     TypeId data_type = input_kernel_tensors_[i]->dtype_id();
236     size_t serialized_data_size =
237       SerializeSingleDynamicShapeInput(rpc_data + offset, shapes, data_type, input_device_tensors_[i]);
238     offset += serialized_data_size;
239   }
240 
241   if (workspace_addr->GetSize() != offset) {
242     MS_LOG(EXCEPTION) << "Send void data size is not the same as workspace size.";
243   }
244   MS_EXCEPTION_IF_NULL(message);
245   message->data = workspace_addr->GetMutablePtr();
246   message->size = workspace_addr->GetSize();
247 }
248 
SerializeCommonMessage(MessageBase * message,const DeviceTensor * workspace_addr) const249 void SendActor::SerializeCommonMessage(MessageBase *message, const DeviceTensor *workspace_addr) const {
250   MS_EXCEPTION_IF_NULL(message);
251   MS_EXCEPTION_IF_NULL(workspace_addr);
252   MS_EXCEPTION_IF_NULL(workspace_addr->GetMutablePtr());
253   size_t total_size = 0;
254   total_size =
255     std::accumulate(input_device_tensors_.begin(), input_device_tensors_.end(), total_size,
256                     [](size_t total_size, const DeviceTensor *output) { return total_size + output->GetSize(); });
257   if (workspace_addr->GetSize() != total_size) {
258     MS_LOG(EXCEPTION) << "Workspace size should be the same as inputs size. But got " << workspace_addr->GetSize()
259                       << " and " << total_size;
260   }
261 
262   RpcDataPtr rpc_data = static_cast<RpcDataPtr>(workspace_addr->GetMutablePtr());
263   MS_EXCEPTION_IF_NULL(rpc_data);
264   for (size_t i = 0; i < input_device_tensors_.size(); i++) {
265     MS_EXCEPTION_IF_NULL(input_device_tensors_[i]);
266     if (!CopyRpcDataWithOffset(&rpc_data, input_device_tensors_[i]->GetMutablePtr(),
267                                input_device_tensors_[i]->GetSize())) {
268       MS_LOG(EXCEPTION) << "Failed to copy data for rpc send input " << i;
269     }
270   }
271   message->data = workspace_addr->GetMutablePtr();
272   message->size = workspace_addr->GetSize();
273 }
274 
275 }  // namespace runtime
276 }  // namespace mindspore
277