1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/rpc_node_scheduler.h"
18 #include <vector>
19 #include <string>
20 #include "include/backend/distributed/cluster/topology/compute_graph_node.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "runtime/graph_scheduler/actor/rpc/mux_send_actor.h"
23 #include "runtime/graph_scheduler/actor/rpc/mux_recv_actor.h"
24
25 namespace mindspore {
26 namespace runtime {
27 namespace {
28 // MuxSendActor and MuxRecvActor of the server are used in pairs, and the MuxSendActor
29 // needs to obtain the information(ip and port) of peer that initiates this service from the corresponding
30 // MuxRecvActor to response request, so need to set the paired MuxRecvActor for MuxSendActor.
SetMuxRecvActorForMuxSendActor(const RpcActorSetPtr & rpc_actor_set)31 void SetMuxRecvActorForMuxSendActor(const RpcActorSetPtr &rpc_actor_set) {
32 MS_EXCEPTION_IF_NULL(rpc_actor_set);
33
34 // 1. Check whether exist mux recv actor.
35 bool exist_mux_recv_actor = false;
36 std::vector<RecvActorPtr> recv_actors;
37 for (const auto &recv_actor : rpc_actor_set->recv_actors_) {
38 MS_EXCEPTION_IF_NULL(recv_actor);
39 CNodePtr rpc_recv_kernel = recv_actor->kernel();
40 MS_EXCEPTION_IF_NULL(rpc_recv_kernel);
41 if (common::AnfAlgo::HasNodeAttr(kAttrIsMuxRpcKernel, rpc_recv_kernel) &&
42 (common::AnfAlgo::GetNodeAttr<bool>(rpc_recv_kernel, kAttrIsMuxRpcKernel) == true)) {
43 exist_mux_recv_actor = true;
44 (void)recv_actors.emplace_back(recv_actor);
45 }
46 }
47
48 if (!exist_mux_recv_actor) {
49 return;
50 }
51 if (recv_actors.size() != 1) {
52 MS_LOG(EXCEPTION) << "Currently the actor set is only allowed to contain one MuxRecvActor, but got: "
53 << recv_actors.size();
54 }
55
56 // 2. Set mux recv actor for mux send actor.
57 MuxRecvActorPtr mux_recv_actor = std::dynamic_pointer_cast<MuxRecvActor>(recv_actors.front());
58 MS_EXCEPTION_IF_NULL(mux_recv_actor);
59
60 for (const auto &send_actor : rpc_actor_set->send_actors_) {
61 MS_EXCEPTION_IF_NULL(send_actor);
62 MuxSendActorPtr mux_send_actor = std::dynamic_pointer_cast<MuxSendActor>(send_actor);
63 MS_EXCEPTION_IF_NULL(mux_send_actor);
64 mux_send_actor->set_mux_recv_actor(mux_recv_actor);
65 }
66 }
67 } // namespace
68
Build(const ActorSet * actor_set)69 RpcActorSetPtr RpcNodeScheduler::Build(const ActorSet *actor_set) {
70 MS_EXCEPTION_IF_NULL(actor_set);
71
72 // RpcActor inherits from KernelActor, so we need to filter out the rpc actors from kernel actors list.
73 std::vector<KernelActorPtr> kernel_actors = actor_set->kernel_actors_;
74 RpcActorSetPtr rpc_actor_set = std::make_shared<RpcActorSet>();
75 MS_EXCEPTION_IF_NULL(rpc_actor_set);
76
77 std::vector<RpcActorPtr> rpc_actors;
78 for (const auto &kernel_actor : kernel_actors) {
79 auto rpc_actor = std::dynamic_pointer_cast<RpcActor>(kernel_actor);
80 if (std::dynamic_pointer_cast<RpcActor>(kernel_actor) == nullptr) {
81 continue;
82 } else {
83 (void)rpc_actors.emplace_back(rpc_actor);
84 if (std::dynamic_pointer_cast<SendActor>(rpc_actor) != nullptr) {
85 (void)rpc_actor_set->send_actors_.emplace_back(std::dynamic_pointer_cast<SendActor>(rpc_actor));
86 } else if (std::dynamic_pointer_cast<RecvActor>(rpc_actor) != nullptr) {
87 (void)rpc_actor_set->recv_actors_.emplace_back(std::dynamic_pointer_cast<RecvActor>(rpc_actor));
88 } else {
89 MS_LOG(EXCEPTION) << "Rpc actor should be either SendActor or RecvActor.";
90 }
91 }
92 }
93
94 // Set the paired MuxRecvActor for MuxSendActor, used in embedding cache case.
95 SetMuxRecvActorForMuxSendActor(rpc_actor_set);
96
97 // Create route table proxy for each rpc actor and set.
98 for (auto &rpc_actor : rpc_actors) {
99 auto proxy = CreateRouteTableProxy();
100 MS_EXCEPTION_IF_NULL(rpc_actor);
101 MS_EXCEPTION_IF_NULL(proxy);
102 rpc_actor->set_actor_route_table_proxy(proxy);
103 }
104
105 // Update the reference counts of rpc kernel inputs and workspaces.
106 UpdateRpcActorRefCounts(rpc_actor_set);
107
108 return rpc_actor_set;
109 }
110
Link(const ActorSet * actor_set) const111 void RpcNodeScheduler::Link(const ActorSet *actor_set) const {
112 MS_EXCEPTION_IF_NULL(actor_set);
113 RpcActorSetPtr rpc_actor_set = actor_set->rpc_actors_;
114 MS_EXCEPTION_IF_NULL(rpc_actor_set);
115 std::vector<SendActorPtr> send_actors = rpc_actor_set->send_actors_;
116 std::vector<RecvActorPtr> recv_actors = rpc_actor_set->recv_actors_;
117
118 // The inter-process edge is connected to a remote peer. So the peer info attributes in the kernel should be
119 // sufficient for route table.
120 for (auto &send_actor : send_actors) {
121 MS_EXCEPTION_IF_NULL(send_actor);
122 CNodePtr rpc_send_kernel = send_actor->kernel();
123 MS_EXCEPTION_IF_NULL(rpc_send_kernel);
124
125 auto send_dst_ranks = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_send_kernel, kAttrSendDstRanks);
126 auto send_dst_roles = common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_send_kernel, kAttrSendDstRoles);
127 std::string send_src_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendSrcNodeName);
128 std::string send_dst_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendDstNodeName);
129 std::vector<std::string> edge_names =
130 common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_send_kernel, kAttrInterProcessEdgeNames);
131
132 if (send_dst_ranks.empty() || send_dst_roles.empty()) {
133 MS_LOG_WITH_NODE(EXCEPTION, rpc_send_kernel)
134 << "The attributes of send node " << rpc_send_kernel->fullname_with_scope()
135 << " is invalid. send_dst_ranks: " << send_dst_ranks << ", send_dst_roles: " << send_dst_roles
136 << ", send_src_node_name: " << send_src_node_name << ", send_dst_node_name: " << send_dst_node_name;
137 }
138 send_actor->set_inter_process_edge_names(edge_names);
139 send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
140 }
141 for (auto &recv_actor : recv_actors) {
142 MS_EXCEPTION_IF_NULL(recv_actor);
143 CNodePtr rpc_recv_kernel = recv_actor->kernel();
144 MS_EXCEPTION_IF_NULL(rpc_recv_kernel);
145
146 auto recv_src_ranks = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_recv_kernel, kAttrRecvSrcRanks);
147 auto recv_src_roles = common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_recv_kernel, kAttrRecvSrcRoles);
148 std::string recv_src_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvSrcNodeName);
149 std::string recv_dst_node_name = common::AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvDstNodeName);
150 std::vector<std::string> edge_names =
151 common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_recv_kernel, kAttrInterProcessEdgeNames);
152
153 if (recv_src_ranks.empty() || recv_src_roles.empty()) {
154 MS_LOG_WITH_NODE(EXCEPTION, rpc_recv_kernel)
155 << "The attributes of recv node " << rpc_recv_kernel->fullname_with_scope()
156 << " is invalid. recv_src_ranks: " << recv_src_ranks << ", recv_src_roles: " << recv_src_roles
157 << ", recv_src_node_name: " << recv_src_node_name << ", recv_dst_node_name: " << recv_dst_node_name;
158 }
159 recv_actor->set_inter_process_edge_names(edge_names);
160 recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
161 }
162 }
163
Schedule(const ActorSet * actor_set) const164 void RpcNodeScheduler::Schedule(const ActorSet *actor_set) const {
165 MS_EXCEPTION_IF_NULL(actor_set);
166 RpcActorSetPtr rpc_actor_set = actor_set->rpc_actors_;
167 MS_EXCEPTION_IF_NULL(rpc_actor_set);
168 // Must start server and register route table before looking up route and connecting.
169
170 // Start servers of recv actors and register route table.
171 for (auto &recv_actor : rpc_actor_set->recv_actors_) {
172 MS_EXCEPTION_IF_NULL(recv_actor);
173 if (!recv_actor->StartServer()) {
174 MS_LOG(EXCEPTION) << "Failed to start server for the recv actor.";
175 }
176 }
177 // Lookup route and connect to servers for send actors.
178 for (auto &send_actor : rpc_actor_set->send_actors_) {
179 MS_EXCEPTION_IF_NULL(send_actor);
180 if (!send_actor->ConnectServer()) {
181 MS_LOG(EXCEPTION) << "Failed to connect servers for the send actor.";
182 }
183 }
184 }
185
SetOpcontext(const RpcActorSetPtr & rpc_actors,OpContext<DeviceTensor> * const op_context)186 void RpcNodeScheduler::SetOpcontext(const RpcActorSetPtr &rpc_actors, OpContext<DeviceTensor> *const op_context) {
187 MS_EXCEPTION_IF_NULL(op_context);
188 MS_EXCEPTION_IF_NULL(rpc_actors);
189
190 for (auto &recv_actor : rpc_actors->recv_actors_) {
191 MS_EXCEPTION_IF_NULL(recv_actor);
192 recv_actor->SetOpcontext(op_context);
193 }
194 for (auto &send_actor : rpc_actors->send_actors_) {
195 MS_EXCEPTION_IF_NULL(send_actor);
196 send_actor->SetOpcontext(op_context);
197 }
198
199 // Set op_context and rpc actor set for later usage.
200 op_context_ = op_context;
201 rpc_actors_ = rpc_actors;
202 }
203
ResetOpcontext(const RpcActorSetPtr & rpc_actors)204 void RpcNodeScheduler::ResetOpcontext(const RpcActorSetPtr &rpc_actors) {
205 MS_EXCEPTION_IF_NULL(rpc_actors);
206
207 for (auto &recv_actor : rpc_actors->recv_actors_) {
208 MS_EXCEPTION_IF_NULL(recv_actor);
209 recv_actor->ResetOpcontext();
210 }
211 for (auto &send_actor : rpc_actors->send_actors_) {
212 MS_EXCEPTION_IF_NULL(send_actor);
213 send_actor->ResetOpcontext();
214 }
215 op_context_ = nullptr;
216 }
217
Clear()218 void RpcNodeScheduler::Clear() {
219 if (rpc_actors_ != nullptr) {
220 MS_LOG(INFO) << "Start finalizing tcp server and client for rpc actors.";
221 for (auto &recv_actor : rpc_actors_->recv_actors_) {
222 MS_EXCEPTION_IF_NULL(recv_actor);
223 recv_actor->Clear();
224 }
225 for (auto &send_actor : rpc_actors_->send_actors_) {
226 MS_EXCEPTION_IF_NULL(send_actor);
227 send_actor->Clear();
228 }
229 rpc_actors_ = nullptr;
230 MS_LOG(INFO) << "End finalizing tcp server and client for rpc actors.";
231 }
232 }
233
Abort()234 void RpcNodeScheduler::Abort() {
235 MS_LOG(INFO) << "Start aborting rpc actors.";
236 if (rpc_actors_ == nullptr) {
237 return;
238 }
239 for (const auto &recv_actor : rpc_actors_->recv_actors_) {
240 MS_EXCEPTION_IF_NULL(recv_actor);
241 ActorDispatcher::Send(recv_actor->GetAID(), &RecvActor::StopRpcAtException);
242 }
243 MS_LOG(INFO) << "End aborting rpc actors.";
244
245 if (op_context_ != nullptr) {
246 // Set op_context success to exit output actor.
247 SET_OPCONTEXT_SUCCESS_RET(*op_context_);
248 }
249 }
250
UpdateRpcActorRefCounts(RpcActorSetPtr rpc_actor_set) const251 void RpcNodeScheduler::UpdateRpcActorRefCounts(RpcActorSetPtr rpc_actor_set) const {
252 MS_EXCEPTION_IF_NULL(rpc_actor_set);
253 for (const auto &send_actor : rpc_actor_set->send_actors_) {
254 MS_EXCEPTION_IF_NULL(send_actor);
255 auto kernel_mod = AnfAlgo::GetKernelMod(send_actor->kernel_);
256 MS_EXCEPTION_IF_NULL(kernel_mod);
257 size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
258 if (workspace_num == 0) {
259 MS_LOG(EXCEPTION) << "Rpc send kernel must have workspace assigned.";
260 }
261 for (size_t i = 0; i < workspace_num; ++i) {
262 auto device_tensor = AnfAlgo::GetMutableWorkspaceAddr(send_actor->kernel_, i);
263 MS_EXCEPTION_IF_NULL(device_tensor);
264 UpdateRefCount(device_tensor.get());
265 }
266 }
267 }
268
CreateRouteTableProxy() const269 ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() const {
270 ActorRouteTableProxyPtr actor_route_table_proxy;
271 if (!ClusterContext::instance()->IsScheduler()) {
272 auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(
273 ClusterContext::instance()->node_base());
274 actor_route_table_proxy = std::make_shared<ActorRouteTableProxy>(cgn);
275 MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
276 }
277 return actor_route_table_proxy;
278 }
279
GetInstance()280 RpcActorStatusUpdater &RpcActorStatusUpdater::GetInstance() {
281 static RpcActorStatusUpdater instance;
282 return instance;
283 }
284
set_rpc_actors(const std::string & graph_name,const RpcActorSetPtr & rpc_actors)285 void RpcActorStatusUpdater::set_rpc_actors(const std::string &graph_name, const RpcActorSetPtr &rpc_actors) {
286 if (rpc_actors != nullptr) {
287 graph_to_rpc_actors_[graph_name] = rpc_actors;
288 }
289 }
290
UpdateRpcActorStatus(const std::string & graph_name)291 void RpcActorStatusUpdater::UpdateRpcActorStatus(const std::string &graph_name) {
292 // Update status for recv actors to control their execution orders.
293 if (graph_to_rpc_actors_.count(graph_name) != 0) {
294 auto rpc_actors = graph_to_rpc_actors_[graph_name];
295 if (rpc_actors.lock() != nullptr) {
296 for (auto &recv_actor : rpc_actors.lock()->recv_actors_) {
297 MS_EXCEPTION_IF_NULL(recv_actor);
298 recv_actor->UpdateStatus();
299 }
300 }
301 }
302 }
303
FlushRpcData(const std::string & graph_name)304 void RpcActorStatusUpdater::FlushRpcData(const std::string &graph_name) {
305 // Flush data for send actors.
306 if (graph_to_rpc_actors_.count(graph_name) != 0) {
307 auto rpc_actors = graph_to_rpc_actors_[graph_name];
308 if (rpc_actors.lock() != nullptr) {
309 for (auto &send_actor : rpc_actors.lock()->send_actors_) {
310 MS_EXCEPTION_IF_NULL(send_actor);
311 send_actor->FlushData();
312 }
313 }
314 }
315 }
316 } // namespace runtime
317 } // namespace mindspore
318