1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/distributed/ps/ps_context.h"
18
19 #include "ir/tensor.h"
20 #include "kernel/kernel.h"
21 #include "utils/log_adapter.h"
22 #include "utils/ms_utils.h"
23 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
24 #include "include/backend/distributed/cluster/cluster_context.h"
25 #include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
26 #include "ps/core/cluster_config.h"
27 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
28 #else
29 #include "include/backend/distributed/cluster/dummy_cluster_context.h"
30 #include "ps/core/cluster_config.h"
31 #endif
32
33 namespace mindspore {
34 namespace ps {
namespace {
// Initial value given to scheduler_port_ by the PSContext constructor.
constexpr uint16_t kDefaultSchedPort{6667};
// Initial value given to scheduler_manage_port_ by the PSContext constructor.
constexpr uint16_t kDefaultSchedManagerPort{11202};
}  // namespace
PSContext()39 PSContext::PSContext()
40 : ps_enabled_(false),
41 is_worker_(false),
42 is_pserver_(false),
43 is_sched_(false),
44 rank_id_(0),
45 worker_num_(0),
46 server_num_(0),
47 scheduler_host_("0.0.0.0"),
48 scheduler_port_(kDefaultSchedPort),
49 role_(kEnvRoleOfNotPS),
50 server_mode_(""),
51 cluster_config_(nullptr),
52 scheduler_manage_port_(kDefaultSchedManagerPort),
53 config_file_path_(""),
54 node_id_(""),
55 enable_ssl_(false),
56 client_password_(),
57 server_password_(),
58 http_url_prefix_(""),
59 instance_name_("") {}
60
~PSContext()61 PSContext::~PSContext() {}
62
instance()63 std::shared_ptr<PSContext> PSContext::instance() {
64 static std::once_flag init_flag;
65 static std::shared_ptr<PSContext> ps_instance = nullptr;
66 std::call_once(init_flag, [&]() {
67 if (ps_instance == nullptr) {
68 ps_instance.reset(new (std::nothrow) PSContext());
69 MS_EXCEPTION_IF_NULL(ps_instance);
70 }
71 });
72
73 return ps_instance;
74 }
75
SetPSEnable(bool enabled)76 void PSContext::SetPSEnable(bool enabled) {
77 ps_enabled_ = enabled;
78 if (ps_enabled_) {
79 std::string ms_role = common::GetEnv(kEnvRole);
80 if (ms_role == "") {
81 ms_role = this->ms_role();
82 }
83 MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
84
85 if (ms_role == kEnvRoleOfWorker) {
86 is_worker_ = true;
87 } else if (ms_role == kEnvRoleOfPServer || ms_role == kEnvRoleOfServer) {
88 is_pserver_ = true;
89 } else if (ms_role == kEnvRoleOfScheduler) {
90 is_sched_ = true;
91 }
92
93 worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
94 server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
95 scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
96 if (scheduler_host_.length() > kLength) {
97 MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
98 }
99 scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
100 if (scheduler_port_ > kMaxPort) {
101 MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
102 }
103 scheduler_manage_port_ =
104 static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
105 if (scheduler_manage_port_ > kMaxPort) {
106 MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
107 }
108 cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
109 node_id_ = common::GetEnv(kEnvNodeId);
110 if (node_id_.length() > kLength) {
111 MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
112 }
113 server_mode_ = kServerModePS;
114 } else {
115 MS_LOG(INFO) << "PS mode is disabled.";
116 is_worker_ = false;
117 is_pserver_ = false;
118 is_sched_ = false;
119 }
120 }
121
is_ps_mode() const122 bool PSContext::is_ps_mode() const { return ps_enabled_; }
123
Reset()124 void PSContext::Reset() {
125 ps_enabled_ = false;
126 is_worker_ = false;
127 is_pserver_ = false;
128 is_sched_ = false;
129 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
130 if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
131 set_cache_enable(false);
132 }
133 #endif
134 }
135
ms_role() const136 std::string PSContext::ms_role() const {
137 if (is_worker_) {
138 return kEnvRoleOfWorker;
139 } else if (is_pserver_) {
140 return kEnvRoleOfPServer;
141 } else if (is_sched_) {
142 return kEnvRoleOfScheduler;
143 } else {
144 return kEnvRoleOfNotPS;
145 }
146 }
147
is_worker() const148 bool PSContext::is_worker() const {
149 if (distributed::cluster::ClusterContext::instance()->initialized()) {
150 return role_ == kEnvRoleOfWorker;
151 }
152 return is_worker_;
153 }
154
is_server() const155 bool PSContext::is_server() const {
156 if (distributed::cluster::ClusterContext::instance()->initialized()) {
157 return role_ == kEnvRoleOfServer || role_ == kEnvRoleOfPServer;
158 }
159 return is_pserver_;
160 }
161
is_scheduler() const162 bool PSContext::is_scheduler() const {
163 if (distributed::cluster::ClusterContext::instance()->initialized()) {
164 return role_ == kEnvRoleOfScheduler;
165 }
166 return is_sched_;
167 }
168
initial_worker_num() const169 uint32_t PSContext::initial_worker_num() const { return worker_num_; }
170
initial_server_num() const171 uint32_t PSContext::initial_server_num() const { return server_num_; }
172
scheduler_host() const173 std::string PSContext::scheduler_host() const { return scheduler_host_; }
174
SetPSRankId(uint32_t rank_id)175 void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
176
ps_rank_id() const177 uint32_t PSContext::ps_rank_id() const { return rank_id_; }
178
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size,int32_t param_key) const179 void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
180 size_t vocab_size, int32_t param_key) const {
181 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
182 if (enable_distributed_mindrt()) {
183 embedding_cache_table_manager.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size,
184 param_key);
185 }
186 #endif
187 }
188
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name) const189 void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name) const {
190 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
191 if (enable_distributed_mindrt()) {
192 embedding_cache_table_manager.ReInsertHashTableSize(new_param_name, cur_param_name);
193 }
194 #endif
195 }
196
InsertAccumuInitInfo(const std::string & param_name,float init_val) const197 void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const {
198 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
199 embedding_cache_table_manager.InsertAccumuInitInfo(param_name, init_val);
200 #endif
201 }
202
CloneHashTable(const std::string & dest_param_name,int32_t dest_param_key,const std::string & src_param_name,int32_t src_param_key) const203 void PSContext::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
204 const std::string &src_param_name, int32_t src_param_key) const {
205 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
206 if (enable_distributed_mindrt()) {
207 embedding_cache_table_manager.CloneHashTable(dest_param_name, dest_param_key, src_param_name, src_param_key);
208 }
209 #endif
210 }
211
set_cache_enable(bool cache_enable) const212 void PSContext::set_cache_enable(bool cache_enable) const {
213 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
214 PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
215 #endif
216 }
217
cache_enable() const218 bool PSContext::cache_enable() const {
219 #if ((defined ENABLE_CPU) && (!defined _WIN32)) && !defined(__APPLE__)
220 return PsDataPrefetch::GetInstance().cache_enable();
221 #endif
222 return false;
223 }
224
set_cache_size(size_t cache_size) const225 void PSContext::set_cache_size(size_t cache_size) const {
226 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
227 distributed::EmbeddingCacheTableManager::GetInstance().set_cache_size(cache_size);
228 #endif
229 }
230
set_sparse_format(bool is_sparse)231 void PSContext::set_sparse_format(bool is_sparse) {
232 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
233 distributed::EmbeddingCacheTableManager::GetInstance().set_sparse_format(is_sparse);
234 #endif
235 }
236
set_rank_id(uint32_t) const237 void PSContext::set_rank_id(uint32_t) const { return; }
238
set_server_mode(const std::string & server_mode)239 void PSContext::set_server_mode(const std::string &server_mode) {
240 if (server_mode != kServerModePS) {
241 MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS;
242 return;
243 }
244 MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
245 server_mode_ = server_mode;
246 }
247
server_mode() const248 const std::string &PSContext::server_mode() const { return server_mode_; }
249
set_ms_role(const std::string & role)250 void PSContext::set_ms_role(const std::string &role) {
251 if (role != kEnvRoleOfWorker && role != kEnvRoleOfPServer && role != kEnvRoleOfServer &&
252 role != kEnvRoleOfScheduler) {
253 MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
254 return;
255 }
256 MS_LOG(INFO) << "MS_ROLE of this node is " << role;
257 role_ = role;
258 }
259
set_worker_num(uint32_t worker_num)260 void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
worker_num() const261 uint32_t PSContext::worker_num() const { return worker_num_; }
262
set_server_num(uint32_t server_num)263 void PSContext::set_server_num(uint32_t server_num) { server_num_ = server_num; }
server_num() const264 uint32_t PSContext::server_num() const { return server_num_; }
265
set_scheduler_ip(const std::string & sched_ip)266 void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
267
scheduler_ip() const268 std::string PSContext::scheduler_ip() const { return scheduler_host_; }
269
set_scheduler_port(uint16_t sched_port)270 void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
271
scheduler_port() const272 uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
273
cluster_config()274 core::ClusterConfig &PSContext::cluster_config() {
275 if (cluster_config_ == nullptr) {
276 cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
277 MS_EXCEPTION_IF_NULL(cluster_config_);
278 }
279 return *cluster_config_;
280 }
281
set_scheduler_manage_port(uint16_t sched_port)282 void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
scheduler_manage_port() const283 uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
284
set_config_file_path(const std::string & path)285 void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
286
config_file_path() const287 std::string PSContext::config_file_path() const { return config_file_path_; }
288
set_node_id(const std::string & node_id)289 void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
290
node_id() const291 const std::string &PSContext::node_id() const { return node_id_; }
292
enable_ssl() const293 bool PSContext::enable_ssl() const { return enable_ssl_; }
294
set_enable_ssl(bool enabled)295 void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
296
client_password()297 char *PSContext::client_password() { return client_password_; }
set_client_password(const char * password)298 void PSContext::set_client_password(const char *password) {
299 if (password == nullptr) {
300 MS_LOG(EXCEPTION) << "Can't set None or nullptr for client password.";
301 }
302 if (strlen(password) >= kMaxPasswordLen) {
303 MS_LOG(EXCEPTION) << "Client password is longer than max password length " << kMaxPasswordLen;
304 }
305 int ret = memcpy_s(client_password_, kMaxPasswordLen, password, strlen(password));
306 if (ret != EOK) {
307 MS_LOG(EXCEPTION) << "memcpy_s client password failed, error: " << ret;
308 }
309 }
310
ClearClientPassword()311 void PSContext::ClearClientPassword() {
312 int ret = memset_s(client_password_, kMaxPasswordLen, 0x00, kMaxPasswordLen);
313 if (ret != 0) {
314 MS_LOG(EXCEPTION) << "Clear client password failed, error: " << ret;
315 }
316 }
317
server_password()318 char *PSContext::server_password() { return server_password_; }
set_server_password(const char * password)319 void PSContext::set_server_password(const char *password) {
320 if (password == nullptr) {
321 MS_LOG(EXCEPTION) << "Can't set None or nullptr for server password.";
322 }
323 if (strlen(password) >= kMaxPasswordLen) {
324 MS_LOG(EXCEPTION) << "Client password is longer than max password length " << kMaxPasswordLen;
325 }
326 int ret = memcpy_s(server_password_, kMaxPasswordLen, password, strlen(password));
327 if (ret != EOK) {
328 MS_LOG(EXCEPTION) << "memcpy_s server password failed, error: " << ret;
329 }
330 }
331
ClearServerPassword()332 void PSContext::ClearServerPassword() {
333 int ret = memset_s(server_password_, kMaxPasswordLen, 0x00, kMaxPasswordLen);
334 if (ret != 0) {
335 MS_LOG(EXCEPTION) << "Clear client password failed, error: " << ret;
336 }
337 }
338
http_url_prefix() const339 std::string PSContext::http_url_prefix() const { return http_url_prefix_; }
340
set_instance_name(const std::string & instance_name)341 void PSContext::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
342
instance_name() const343 const std::string &PSContext::instance_name() const { return instance_name_; }
344
enable_distributed_mindrt() const345 bool PSContext::enable_distributed_mindrt() const {
346 bool ms_cluster_enabled = distributed::cluster::ClusterContext::instance()->initialized();
347 return ms_cluster_enabled;
348 }
349
set_checkpoint_load_status(bool status)350 void PSContext::set_checkpoint_load_status(bool status) {
351 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
352 return embedding_cache_table_manager.set_checkpoint_load_status(status);
353 #endif
354 }
355
StoreWarmUpPtrByTensor(const int32_t param_key,const tensor::TensorPtr & tensor_ptr)356 int32_t PSContext::StoreWarmUpPtrByTensor(const int32_t param_key, const tensor::TensorPtr &tensor_ptr) {
357 MS_EXCEPTION_IF_NULL(tensor_ptr);
358 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
359 return embedding_cache_table_manager.StoreWarmUpPtr(param_key, tensor_ptr);
360 #else
361 return -1;
362 #endif
363 }
364
StoreWarmUpPtrByTensorList(const int32_t param_key,const tensor::TensorPtr & key_ptr,const tensor::TensorPtr & value_ptr,const tensor::TensorPtr & status_ptr)365 int32_t PSContext::StoreWarmUpPtrByTensorList(const int32_t param_key, const tensor::TensorPtr &key_ptr,
366 const tensor::TensorPtr &value_ptr, const tensor::TensorPtr &status_ptr) {
367 MS_EXCEPTION_IF_NULL(key_ptr);
368 MS_EXCEPTION_IF_NULL(value_ptr);
369 MS_EXCEPTION_IF_NULL(status_ptr);
370 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
371 return embedding_cache_table_manager.StoreWarmUpPtr(param_key, key_ptr, value_ptr, status_ptr);
372 #else
373 return -1;
374 #endif
375 }
376 } // namespace ps
377 } // namespace mindspore
378