• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/rpc_node_scheduler.h"
18 #include <vector>
19 #include <string>
20 #include "include/backend/distributed/cluster/topology/compute_graph_node.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "runtime/graph_scheduler/actor/rpc/mux_send_actor.h"
23 #include "runtime/graph_scheduler/actor/rpc/mux_recv_actor.h"
24 
25 namespace mindspore {
26 namespace runtime {
27 namespace {
28 // MuxSendActor and MuxRecvActor of the server are used in pairs, and the MuxSendActor
29 // needs to obtain the information(ip and port) of peer that initiates this service from the corresponding
30 // MuxRecvActor to response request, so need to set the paired MuxRecvActor for MuxSendActor.
SetMuxRecvActorForMuxSendActor(const RpcActorSetPtr & rpc_actor_set)31 void SetMuxRecvActorForMuxSendActor(const RpcActorSetPtr &rpc_actor_set) {
32   MS_EXCEPTION_IF_NULL(rpc_actor_set);
33 
34   // 1. Check whether exist mux recv actor.
35   bool exist_mux_recv_actor = false;
36   std::vector<RecvActorPtr> recv_actors;
37   for (const auto &recv_actor : rpc_actor_set->recv_actors_) {
38     MS_EXCEPTION_IF_NULL(recv_actor);
39     CNodePtr rpc_recv_kernel = recv_actor->kernel();
40     MS_EXCEPTION_IF_NULL(rpc_recv_kernel);
41     if (common::AnfAlgo::HasNodeAttr(kAttrIsMuxRpcKernel, rpc_recv_kernel) &&
42         (common::AnfAlgo::GetNodeAttr<bool>(rpc_recv_kernel, kAttrIsMuxRpcKernel) == true)) {
43       exist_mux_recv_actor = true;
44       (void)recv_actors.emplace_back(recv_actor);
45     }
46   }
47 
48   if (!exist_mux_recv_actor) {
49     return;
50   }
51   if (recv_actors.size() != 1) {
52     MS_LOG(EXCEPTION) << "Currently the actor set is only allowed to contain one MuxRecvActor, but got: "
53                       << recv_actors.size();
54   }
55 
56   // 2. Set mux recv actor for mux send actor.
57   MuxRecvActorPtr mux_recv_actor = std::dynamic_pointer_cast<MuxRecvActor>(recv_actors.front());
58   MS_EXCEPTION_IF_NULL(mux_recv_actor);
59 
60   for (const auto &send_actor : rpc_actor_set->send_actors_) {
61     MS_EXCEPTION_IF_NULL(send_actor);
62     MuxSendActorPtr mux_send_actor = std::dynamic_pointer_cast<MuxSendActor>(send_actor);
63     MS_EXCEPTION_IF_NULL(mux_send_actor);
64     mux_send_actor->set_mux_recv_actor(mux_recv_actor);
65   }
66 }
67 }  // namespace
68 
Build(const ActorSet * actor_set)69 RpcActorSetPtr RpcNodeScheduler::Build(const ActorSet *actor_set) {
70   MS_EXCEPTION_IF_NULL(actor_set);
71 
72   // RpcActor inherits from KernelActor, so we need to filter out the rpc actors from kernel actors list.
73   std::vector<KernelActorPtr> kernel_actors = actor_set->kernel_actors_;
74   RpcActorSetPtr rpc_actor_set = std::make_shared<RpcActorSet>();
75   MS_EXCEPTION_IF_NULL(rpc_actor_set);
76 
77   std::vector<RpcActorPtr> rpc_actors;
78   for (const auto &kernel_actor : kernel_actors) {
79     auto rpc_actor = std::dynamic_pointer_cast<RpcActor>(kernel_actor);
80     if (std::dynamic_pointer_cast<RpcActor>(kernel_actor) == nullptr) {
81       continue;
82     } else {
83       (void)rpc_actors.emplace_back(rpc_actor);
84       if (std::dynamic_pointer_cast<SendActor>(rpc_actor) != nullptr) {
85         (void)rpc_actor_set->send_actors_.emplace_back(std::dynamic_pointer_cast<SendActor>(rpc_actor));
86       } else if (std::dynamic_pointer_cast<RecvActor>(rpc_actor) != nullptr) {
87         (void)rpc_actor_set->recv_actors_.emplace_back(std::dynamic_pointer_cast<RecvActor>(rpc_actor));
88       } else {
89         MS_LOG(EXCEPTION) << "Rpc actor should be either SendActor or RecvActor.";
90       }
91     }
92   }
93 
94   // Set the paired MuxRecvActor for MuxSendActor, used in embedding cache case.
95   SetMuxRecvActorForMuxSendActor(rpc_actor_set);
96 
97   // Create route table proxy for each rpc actor and set.
98   for (auto &rpc_actor : rpc_actors) {
99     auto proxy = CreateRouteTableProxy();
100     MS_EXCEPTION_IF_NULL(rpc_actor);
101     MS_EXCEPTION_IF_NULL(proxy);
102     rpc_actor->set_actor_route_table_proxy(proxy);
103   }
104 
105   // Update the reference counts of rpc kernel inputs and workspaces.
106   UpdateRpcActorRefCounts(rpc_actor_set);
107 
108   return rpc_actor_set;
109 }
110 
Link(const ActorSet * actor_set) const111 void RpcNodeScheduler::Link(const ActorSet *actor_set) const {
112   MS_EXCEPTION_IF_NULL(actor_set);
113   RpcActorSetPtr rpc_actor_set = actor_set->rpc_actors_;
114   MS_EXCEPTION_IF_NULL(rpc_actor_set);
115   std::vector<SendActorPtr> send_actors = rpc_actor_set->send_actors_;
116   std::vector<RecvActorPtr> recv_actors = rpc_actor_set->recv_actors_;
117 
118   // The inter-process edge is connected to a remote peer. So the peer info attributes in the kernel should be
119   // sufficient for route table.
120   for (auto &send_actor : send_actors) {
121     MS_EXCEPTION_IF_NULL(send_actor);
122     CNodePtr rpc_send_kernel = send_actor->kernel();
123     MS_EXCEPTION_IF_NULL(rpc_send_kernel);
124 
125     auto send_dst_ranks = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_send_kernel, kAttrSendDstRanks);
126     auto send_dst_roles = common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_send_kernel, kAttrSendDstRoles);
127     std::string send_src_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendSrcNodeName);
128     std::string send_dst_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendDstNodeName);
129     std::vector<std::string> edge_names =
130       common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_send_kernel, kAttrInterProcessEdgeNames);
131 
132     if (send_dst_ranks.empty() || send_dst_roles.empty()) {
133       MS_LOG_WITH_NODE(EXCEPTION, rpc_send_kernel)
134         << "The attributes of send node " << rpc_send_kernel->fullname_with_scope()
135         << " is invalid. send_dst_ranks: " << send_dst_ranks << ", send_dst_roles: " << send_dst_roles
136         << ", send_src_node_name: " << send_src_node_name << ", send_dst_node_name: " << send_dst_node_name;
137     }
138     send_actor->set_inter_process_edge_names(edge_names);
139     send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
140   }
141   for (auto &recv_actor : recv_actors) {
142     MS_EXCEPTION_IF_NULL(recv_actor);
143     CNodePtr rpc_recv_kernel = recv_actor->kernel();
144     MS_EXCEPTION_IF_NULL(rpc_recv_kernel);
145 
146     auto recv_src_ranks = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_recv_kernel, kAttrRecvSrcRanks);
147     auto recv_src_roles = common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_recv_kernel, kAttrRecvSrcRoles);
148     std::string recv_src_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvSrcNodeName);
149     std::string recv_dst_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvDstNodeName);
150     std::vector<std::string> edge_names =
151       common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_recv_kernel, kAttrInterProcessEdgeNames);
152 
153     if (recv_src_ranks.empty() || recv_src_roles.empty()) {
154       MS_LOG_WITH_NODE(EXCEPTION, rpc_recv_kernel)
155         << "The attributes of recv node " << rpc_recv_kernel->fullname_with_scope()
156         << " is invalid. recv_src_ranks: " << recv_src_ranks << ", recv_src_roles: " << recv_src_roles
157         << ", recv_src_node_name: " << recv_src_node_name << ", recv_dst_node_name: " << recv_dst_node_name;
158     }
159     recv_actor->set_inter_process_edge_names(edge_names);
160     recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
161   }
162 }
163 
Schedule(const ActorSet * actor_set) const164 void RpcNodeScheduler::Schedule(const ActorSet *actor_set) const {
165   MS_EXCEPTION_IF_NULL(actor_set);
166   RpcActorSetPtr rpc_actor_set = actor_set->rpc_actors_;
167   MS_EXCEPTION_IF_NULL(rpc_actor_set);
168   // Must start server and register route table before looking up route and connecting.
169 
170   // Start servers of recv actors and register route table.
171   for (auto &recv_actor : rpc_actor_set->recv_actors_) {
172     MS_EXCEPTION_IF_NULL(recv_actor);
173     if (!recv_actor->StartServer()) {
174       MS_LOG(EXCEPTION) << "Failed to start server for the recv actor.";
175     }
176   }
177   // Lookup route and connect to servers for send actors.
178   for (auto &send_actor : rpc_actor_set->send_actors_) {
179     MS_EXCEPTION_IF_NULL(send_actor);
180     if (!send_actor->ConnectServer()) {
181       MS_LOG(EXCEPTION) << "Failed to connect servers for the send actor.";
182     }
183   }
184 }
185 
SetOpcontext(const RpcActorSetPtr & rpc_actors,OpContext<DeviceTensor> * const op_context)186 void RpcNodeScheduler::SetOpcontext(const RpcActorSetPtr &rpc_actors, OpContext<DeviceTensor> *const op_context) {
187   MS_EXCEPTION_IF_NULL(op_context);
188   MS_EXCEPTION_IF_NULL(rpc_actors);
189 
190   for (auto &recv_actor : rpc_actors->recv_actors_) {
191     MS_EXCEPTION_IF_NULL(recv_actor);
192     recv_actor->SetOpcontext(op_context);
193   }
194   for (auto &send_actor : rpc_actors->send_actors_) {
195     MS_EXCEPTION_IF_NULL(send_actor);
196     send_actor->SetOpcontext(op_context);
197   }
198 
199   // Set op_context and rpc actor set for later usage.
200   op_context_ = op_context;
201   rpc_actors_ = rpc_actors;
202 }
203 
ResetOpcontext(const RpcActorSetPtr & rpc_actors)204 void RpcNodeScheduler::ResetOpcontext(const RpcActorSetPtr &rpc_actors) {
205   MS_EXCEPTION_IF_NULL(rpc_actors);
206 
207   for (auto &recv_actor : rpc_actors->recv_actors_) {
208     MS_EXCEPTION_IF_NULL(recv_actor);
209     recv_actor->ResetOpcontext();
210   }
211   for (auto &send_actor : rpc_actors->send_actors_) {
212     MS_EXCEPTION_IF_NULL(send_actor);
213     send_actor->ResetOpcontext();
214   }
215   op_context_ = nullptr;
216 }
217 
Clear()218 void RpcNodeScheduler::Clear() {
219   if (rpc_actors_ != nullptr) {
220     MS_LOG(INFO) << "Start finalizing tcp server and client for rpc actors.";
221     for (auto &recv_actor : rpc_actors_->recv_actors_) {
222       MS_EXCEPTION_IF_NULL(recv_actor);
223       recv_actor->Clear();
224     }
225     for (auto &send_actor : rpc_actors_->send_actors_) {
226       MS_EXCEPTION_IF_NULL(send_actor);
227       send_actor->Clear();
228     }
229     rpc_actors_ = nullptr;
230     MS_LOG(INFO) << "End finalizing tcp server and client for rpc actors.";
231   }
232 }
233 
Abort()234 void RpcNodeScheduler::Abort() {
235   MS_LOG(INFO) << "Start aborting rpc actors.";
236   if (rpc_actors_ == nullptr) {
237     return;
238   }
239   for (const auto &recv_actor : rpc_actors_->recv_actors_) {
240     MS_EXCEPTION_IF_NULL(recv_actor);
241     ActorDispatcher::Send(recv_actor->GetAID(), &RecvActor::StopRpcAtException);
242   }
243   MS_LOG(INFO) << "End aborting rpc actors.";
244 
245   if (op_context_ != nullptr) {
246     // Set op_context success to exit output actor.
247     SET_OPCONTEXT_SUCCESS_RET(*op_context_);
248   }
249 }
250 
UpdateRpcActorRefCounts(RpcActorSetPtr rpc_actor_set) const251 void RpcNodeScheduler::UpdateRpcActorRefCounts(RpcActorSetPtr rpc_actor_set) const {
252   MS_EXCEPTION_IF_NULL(rpc_actor_set);
253   for (const auto &send_actor : rpc_actor_set->send_actors_) {
254     MS_EXCEPTION_IF_NULL(send_actor);
255     auto kernel_mod = AnfAlgo::GetKernelMod(send_actor->kernel_);
256     MS_EXCEPTION_IF_NULL(kernel_mod);
257     size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
258     if (workspace_num == 0) {
259       MS_LOG(EXCEPTION) << "Rpc send kernel must have workspace assigned.";
260     }
261     for (size_t i = 0; i < workspace_num; ++i) {
262       auto device_tensor = AnfAlgo::GetMutableWorkspaceAddr(send_actor->kernel_, i);
263       MS_EXCEPTION_IF_NULL(device_tensor);
264       UpdateRefCount(device_tensor.get());
265     }
266   }
267 }
268 
CreateRouteTableProxy() const269 ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() const {
270   ActorRouteTableProxyPtr actor_route_table_proxy;
271   if (!ClusterContext::instance()->IsScheduler()) {
272     auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
273       ClusterContext::instance()->node_base());
274     actor_route_table_proxy = std::make_shared<ActorRouteTableProxy>(cgn);
275     MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
276   }
277   return actor_route_table_proxy;
278 }
279 
GetInstance()280 RpcActorStatusUpdater &RpcActorStatusUpdater::GetInstance() {
281   static RpcActorStatusUpdater instance;
282   return instance;
283 }
284 
set_rpc_actors(const std::string & graph_name,const RpcActorSetPtr & rpc_actors)285 void RpcActorStatusUpdater::set_rpc_actors(const std::string &graph_name, const RpcActorSetPtr &rpc_actors) {
286   if (rpc_actors != nullptr) {
287     graph_to_rpc_actors_[graph_name] = rpc_actors;
288   }
289 }
290 
UpdateRpcActorStatus(const std::string & graph_name)291 void RpcActorStatusUpdater::UpdateRpcActorStatus(const std::string &graph_name) {
292   // Update status for recv actors to control their execution orders.
293   if (graph_to_rpc_actors_.count(graph_name) != 0) {
294     auto rpc_actors = graph_to_rpc_actors_[graph_name];
295     if (rpc_actors.lock() != nullptr) {
296       for (auto &recv_actor : rpc_actors.lock()->recv_actors_) {
297         MS_EXCEPTION_IF_NULL(recv_actor);
298         recv_actor->UpdateStatus();
299       }
300     }
301   }
302 }
303 
FlushRpcData(const std::string & graph_name)304 void RpcActorStatusUpdater::FlushRpcData(const std::string &graph_name) {
305   // Flush data for send actors.
306   if (graph_to_rpc_actors_.count(graph_name) != 0) {
307     auto rpc_actors = graph_to_rpc_actors_[graph_name];
308     if (rpc_actors.lock() != nullptr) {
309       for (auto &send_actor : rpc_actors.lock()->send_actors_) {
310         MS_EXCEPTION_IF_NULL(send_actor);
311         send_actor->FlushData();
312       }
313     }
314   }
315 }
316 }  // namespace runtime
317 }  // namespace mindspore
318