1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <mutex>
18 #include <vector>
19 #include <string>
20 #include <memory>
21 #include "include/backend/distributed/cluster/cluster_context.h"
22 #include "include/backend/distributed/cluster/topology/common.h"
23 #include "include/backend/distributed/recovery/recovery_context.h"
24 #include "include/backend/distributed/cluster/topology/compute_graph_node.h"
25 #include "distributed/cluster/topology/meta_server_node.h"
26 #include "distributed/cluster/actor_route_table_proxy.h"
27 #include "include/backend/distributed/collective/collective_manager.h"
28 #include "proto/topology.pb.h"
29 #include "utils/ms_context.h"
30 #include "include/backend/distributed/ps/ps_context.h"
31 #include "ps/core/comm_util.h"
32 #include "ps/core/cluster_config.h"
33 #include "include/common/debug/common.h"
34
35 namespace mindspore {
36 namespace distributed {
37 namespace cluster {
ClusterContext()38 ClusterContext::ClusterContext()
39 : inited_(false),
40 finalized_(true),
41 cluster_exit_with_exception_(false),
42 node_num_each_role_({}),
43 scheduler_host_(kLocalHost),
44 scheduler_port_(kDefaultSchedPort),
45 node_id_(""),
46 node_role_(""),
47 cluster_config_(nullptr) {}
48
~ClusterContext()49 ClusterContext::~ClusterContext() {
50 if (!finalized_) {
51 try {
52 const uint32_t timeout = 0;
53 (void)Finalize(timeout);
54 } catch (std::exception &) {
55 MS_LOG(ERROR) << "Failed to finalize cluster context.";
56 }
57 }
58 finalized_ = true;
59 }
60
instance()61 std::shared_ptr<ClusterContext> ClusterContext::instance() {
62 static std::once_flag init_flag;
63 static std::shared_ptr<ClusterContext> cluster_instance = nullptr;
64 std::call_once(init_flag, [&]() {
65 if (cluster_instance == nullptr) {
66 cluster_instance.reset(new (std::nothrow) ClusterContext());
67 MS_EXCEPTION_IF_NULL(cluster_instance);
68 }
69 });
70
71 return cluster_instance;
72 }
73
Initialize()74 bool ClusterContext::Initialize() {
75 if (inited_) {
76 MS_LOG(INFO) << "The cluster has been initialized.";
77 return true;
78 }
79
80 // Step 1: Initialize cluster configuration.
81 InitClusterConfig();
82
83 // Step 2: Build network for this cluster. Every process will block in this method until networking is done.
84 if (!BuildCluster()) {
85 MsException::Instance().CheckException();
86 MS_LOG(ERROR) << "Building networking for " << node_role_ << " failed.";
87 return false;
88 }
89
90 // Step 3: Initialize some modules for the node, e.g., actor route table proxy.
91 if (!IsScheduler()) {
92 // Only node which is not the scheduler needs route table proxy.
93 auto cgn = std::dynamic_pointer_cast<distributed::cluster::topology::ComputeGraphNode>(node_base_);
94 MS_EXCEPTION_IF_NULL(cgn);
95 actor_route_table_proxy_ = std::make_shared<ActorRouteTableProxy>(cgn);
96 MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
97 }
98
99 inited_ = true;
100 finalized_ = false;
101 return true;
102 }
103
Finalize(uint32_t timeout)104 bool ClusterContext::Finalize(uint32_t timeout) {
105 if (finalized_) {
106 return true;
107 }
108 MS_EXCEPTION_IF_NULL(node_base_);
109
110 bool force = (timeout == 0);
111 uint32_t interval = 5;
112 while (!node_base_->Finalize(force)) {
113 MS_LOG(WARNING)
114 << "This log means the cluster is successfully created. Retry to finalize the node and exit cluster...";
115 (void)sleep(interval);
116 }
117 finalized_ = true;
118 return true;
119 }
120
IsScheduler()121 bool ClusterContext::IsScheduler() { return node_role_ == kEnvRoleOfScheduler; }
122
node() const123 const std::shared_ptr<topology::NodeBase> &ClusterContext::node() const { return node_base_; }
124
node_base() const125 const std::shared_ptr<topology::NodeBase> &ClusterContext::node_base() const { return node_base_; }
126
node_role() const127 const std::string &ClusterContext::node_role() const { return node_role_; }
128
node_num(const std::string & node_role)129 uint32_t ClusterContext::node_num(const std::string &node_role) {
130 if (node_num_each_role_.count(node_role) == 0) {
131 MS_LOG(EXCEPTION) << "Node role " << node_role << " is invalid.";
132 }
133 MS_LOG(INFO) << "Number of role " << node_role << " is " << node_num_each_role_[node_role];
134 return node_num_each_role_[node_role];
135 }
136
node_num() const137 uint32_t ClusterContext::node_num() const {
138 uint32_t node_num = 0;
139 for (auto iter = node_num_each_role_.begin(); iter != node_num_each_role_.end(); ++iter) {
140 if (iter->first != kEnvRoleOfScheduler) {
141 node_num += iter->second;
142 }
143 }
144 return node_num;
145 }
146
initialized() const147 bool ClusterContext::initialized() const { return inited_; }
148
actor_route_table_proxy() const149 const ActorRouteTableProxyPtr &ClusterContext::actor_route_table_proxy() const { return actor_route_table_proxy_; }
150
set_cluster_exit_with_exception()151 void ClusterContext::set_cluster_exit_with_exception() { cluster_exit_with_exception_ = true; }
152
cluster_exit_with_exception() const153 bool ClusterContext::cluster_exit_with_exception() const { return cluster_exit_with_exception_; }
154
InitClusterConfig()155 void ClusterContext::InitClusterConfig() {
156 InitNodeRole();
157 InitSchedulerIp();
158 InitSchedulerPort();
159 ps::PSContext::instance()->set_ms_role(node_role_);
160 ps::PSContext::instance()->set_worker_num(node_num_each_role_[kEnvRoleOfWorker]);
161 ps::PSContext::instance()->set_server_num(node_num_each_role_[kEnvRoleOfServer]);
162 ps::PSContext::instance()->set_scheduler_ip(scheduler_host_);
163 ps::PSContext::instance()->set_scheduler_port(scheduler_port_);
164 ps::PSContext::instance()->cluster_config().initial_worker_num = node_num_each_role_[kEnvRoleOfWorker];
165 ps::PSContext::instance()->cluster_config().initial_server_num = node_num_each_role_[kEnvRoleOfServer];
166 ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_host_;
167 ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
168 }
169
BuildCluster()170 bool ClusterContext::BuildCluster() {
171 // Get node_id from environment configuration or uuid generator.
172 node_id_ = common::GetEnv(kNodeId);
173 if (node_id_.length() == 0) {
174 node_id_ = ps::core::CommUtil::GenerateUUID();
175 }
176 // Init the node according to the process role.
177 if (node_role_ == kEnvRoleOfScheduler) {
178 auto node_num = node_num_each_role_[kEnvRoleOfWorker] + node_num_each_role_[kEnvRoleOfServer];
179 node_base_ = std::make_shared<topology::MetaServerNode>(node_id_, node_role_, node_num);
180 } else {
181 node_base_ = std::make_shared<topology::ComputeGraphNode>(node_id_, node_role_);
182 }
183 MS_EXCEPTION_IF_NULL(node_base_);
184 // For cgn, 'Initialize' will block until it connect to msn, or time out.
185 RETURN_IF_FALSE_WITH_LOG(node_base_->Initialize(), "Failed to initialize the node.");
186
187 // Check the state of topology construction.
188 auto check_func = [this]() -> bool {
189 // Check exception thrown by child threads in cgn or msn.
190 MsException::Instance().CheckException();
191 return this->node_base_->Initialized();
192 };
193 size_t retry_num = node_base_->topo_timeout() / topology::kExecuteInterval;
194 EXECUTE_WITH_RETRY(check_func, retry_num, topology::kExecuteInterval, "Topology build timed out.");
195
196 MS_LOG(WARNING) << "Cluster is successfully initialized.";
197 PostProcess();
198 return true;
199 }
200
InitNodeRole()201 void ClusterContext::InitNodeRole() {
202 node_role_ = common::GetEnv(kEnvRole);
203 if (kValidRoleName.count(node_role_) == 0) {
204 MS_LOG(EXCEPTION) << "Role name '" << node_role_ << "' is invalid. " << kDetailedFailureReason;
205 }
206
207 if (common::GetEnv(kEnvWorkerNum).empty()) {
208 if (node_role_ == kEnvRoleOfWorker) {
209 MS_LOG(EXCEPTION) << "Please set env 'WORKER_NUM' to a number greater than 0.";
210 }
211 node_num_each_role_[kEnvRoleOfWorker] = 0;
212 } else {
213 TRY_AND_CATCH_WITH_EXCEPTION(
214 (node_num_each_role_[kEnvRoleOfWorker] = IntToUint(std::stoi(common::GetEnv(kEnvWorkerNum)))),
215 "The environment variable MS_WORKER_NUM is invalid.");
216 }
217
218 // MS_PSERVER is supported for now. It should be deprecated after we use cluster for distributed training.
219 if (common::GetEnv(kEnvServerNum).empty()) {
220 if (node_role_ == kEnvRoleOfServer || node_role_ == kEnvRoleOfPServer) {
221 MS_LOG(EXCEPTION) << "Please set env 'SERVER_NUM' to a number greater than 0.";
222 }
223 node_num_each_role_[kEnvRoleOfServer] = 0;
224 node_num_each_role_[kEnvRoleOfPServer] = 0;
225 } else {
226 TRY_AND_CATCH_WITH_EXCEPTION(
227 (node_num_each_role_[kEnvRoleOfServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)))),
228 "The environment variable MS_SERVER_NUM is invalid.");
229 TRY_AND_CATCH_WITH_EXCEPTION(
230 (node_num_each_role_[kEnvRoleOfPServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)))),
231 "The environment variable MS_SERVER_NUM is invalid.");
232 }
233 }
234
InitSchedulerIp()235 void ClusterContext::InitSchedulerIp() {
236 scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
237 if (scheduler_host_.empty()) {
238 MS_LOG(EXCEPTION) << kEnvSchedulerHost << " is empty. " << kEnvSchedulerHost;
239 }
240 }
241
InitSchedulerPort()242 void ClusterContext::InitSchedulerPort() {
243 TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port_ = static_cast<uint16_t>(std::stoi(common::GetEnv(kEnvSchedulerPort)))),
244 "The environment variable MS_SCHED_PORT is invalid.");
245 if (scheduler_port_ > kMaxPort) {
246 MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is invalid.";
247 }
248 }
249
PostProcess()250 void ClusterContext::PostProcess() {
251 if (node_role_ != kEnvRoleOfScheduler) {
252 auto cgn = std::dynamic_pointer_cast<topology::ComputeGraphNode>(node_base_);
253 MS_EXCEPTION_IF_NULL(cgn);
254 MS_LOG(INFO) << "Start post processing for computing graph nodes.";
255
256 // 1. Get new rank id from meta server node because it may be reassigned.
257 std::string final_rank_id = cgn->GetMetadata(node_role_ + node_id_);
258 if (!final_rank_id.empty()) {
259 cgn->set_rank_id(static_cast<uint32_t>(std::atoi(final_rank_id.c_str())));
260 MS_LOG(WARNING) << "This node " << node_id_ << " rank id: " << final_rank_id;
261 } else {
262 MS_LOG(WARNING) << "This node could be redundant and is not successfully registered.";
263 }
264
265 // 2. Set this node's client ip address in this cluster.
266 const std::string &client_ip_in_cluster = cgn->client_ip();
267 MS_LOG(INFO) << "Client ip address in this cluster of this compute graph node is " << client_ip_in_cluster;
268 (void)common::SetEnv(kEnvWorkerIp, client_ip_in_cluster.c_str());
269
270 // 3. Set port range of this node.
271 std::string port_range_pb = cgn->GetMetadata(kNodePortRange);
272 topology::NodePortRanges node_port_ranges;
273 (void)node_port_ranges.ParseFromArray(port_range_pb.c_str(), SizeToInt(port_range_pb.size()));
274 if (node_port_ranges.data().count(node_id_) != 0) {
275 auto port_range = node_port_ranges.data().at(node_id_);
276 port_range_.first = port_range.min_port();
277 port_range_.second = port_range.max_port();
278 MS_LOG(INFO) << "Port range assigned for this node " << node_id_ << " is " << port_range_.first << " to "
279 << port_range_.second;
280 }
281 }
282 }
283 } // namespace cluster
284 } // namespace distributed
285 } // namespace mindspore
286