• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/ps/ps_context.h"
18 
19 #include "ir/tensor.h"
20 #include "kernel/kernel.h"
21 #include "utils/log_adapter.h"
22 #include "utils/ms_utils.h"
23 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
24 #include "include/backend/distributed/cluster/cluster_context.h"
25 #include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
26 #include "ps/core/cluster_config.h"
27 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
28 #else
29 #include "include/backend/distributed/cluster/dummy_cluster_context.h"
30 #include "ps/core/cluster_config.h"
31 #endif
32 
33 namespace mindspore {
34 namespace ps {
namespace {
// Initial value given to PSContext::scheduler_port_ by the constructor.
constexpr uint16_t kDefaultSchedPort{6667};
// Initial value given to PSContext::scheduler_manage_port_ by the constructor.
constexpr uint16_t kDefaultSchedManagerPort{11202};
}  // namespace
PSContext()39 PSContext::PSContext()
40     : ps_enabled_(false),
41       is_worker_(false),
42       is_pserver_(false),
43       is_sched_(false),
44       rank_id_(0),
45       worker_num_(0),
46       server_num_(0),
47       scheduler_host_("0.0.0.0"),
48       scheduler_port_(kDefaultSchedPort),
49       role_(kEnvRoleOfNotPS),
50       server_mode_(""),
51       cluster_config_(nullptr),
52       scheduler_manage_port_(kDefaultSchedManagerPort),
53       config_file_path_(""),
54       node_id_(""),
55       enable_ssl_(false),
56       client_password_(),
57       server_password_(),
58       http_url_prefix_(""),
59       instance_name_("") {}
60 
~PSContext()61 PSContext::~PSContext() {}
62 
instance()63 std::shared_ptr<PSContext> PSContext::instance() {
64   static std::once_flag init_flag;
65   static std::shared_ptr<PSContext> ps_instance = nullptr;
66   std::call_once(init_flag, [&]() {
67     if (ps_instance == nullptr) {
68       ps_instance.reset(new (std::nothrow) PSContext());
69       MS_EXCEPTION_IF_NULL(ps_instance);
70     }
71   });
72 
73   return ps_instance;
74 }
75 
SetPSEnable(bool enabled)76 void PSContext::SetPSEnable(bool enabled) {
77   ps_enabled_ = enabled;
78   if (ps_enabled_) {
79     std::string ms_role = common::GetEnv(kEnvRole);
80     if (ms_role == "") {
81       ms_role = this->ms_role();
82     }
83     MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
84 
85     if (ms_role == kEnvRoleOfWorker) {
86       is_worker_ = true;
87     } else if (ms_role == kEnvRoleOfPServer || ms_role == kEnvRoleOfServer) {
88       is_pserver_ = true;
89     } else if (ms_role == kEnvRoleOfScheduler) {
90       is_sched_ = true;
91     }
92 
93     worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
94     server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
95     scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
96     if (scheduler_host_.length() > kLength) {
97       MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
98     }
99     scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
100     if (scheduler_port_ > kMaxPort) {
101       MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
102     }
103     scheduler_manage_port_ =
104       static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
105     if (scheduler_manage_port_ > kMaxPort) {
106       MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
107     }
108     cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
109     node_id_ = common::GetEnv(kEnvNodeId);
110     if (node_id_.length() > kLength) {
111       MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
112     }
113     server_mode_ = kServerModePS;
114   } else {
115     MS_LOG(INFO) << "PS mode is disabled.";
116     is_worker_ = false;
117     is_pserver_ = false;
118     is_sched_ = false;
119   }
120 }
121 
is_ps_mode() const122 bool PSContext::is_ps_mode() const { return ps_enabled_; }
123 
Reset()124 void PSContext::Reset() {
125   ps_enabled_ = false;
126   is_worker_ = false;
127   is_pserver_ = false;
128   is_sched_ = false;
129 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
130   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
131     set_cache_enable(false);
132   }
133 #endif
134 }
135 
ms_role() const136 std::string PSContext::ms_role() const {
137   if (is_worker_) {
138     return kEnvRoleOfWorker;
139   } else if (is_pserver_) {
140     return kEnvRoleOfPServer;
141   } else if (is_sched_) {
142     return kEnvRoleOfScheduler;
143   } else {
144     return kEnvRoleOfNotPS;
145   }
146 }
147 
is_worker() const148 bool PSContext::is_worker() const {
149   if (distributed::cluster::ClusterContext::instance()->initialized()) {
150     return role_ == kEnvRoleOfWorker;
151   }
152   return is_worker_;
153 }
154 
is_server() const155 bool PSContext::is_server() const {
156   if (distributed::cluster::ClusterContext::instance()->initialized()) {
157     return role_ == kEnvRoleOfServer || role_ == kEnvRoleOfPServer;
158   }
159   return is_pserver_;
160 }
161 
is_scheduler() const162 bool PSContext::is_scheduler() const {
163   if (distributed::cluster::ClusterContext::instance()->initialized()) {
164     return role_ == kEnvRoleOfScheduler;
165   }
166   return is_sched_;
167 }
168 
initial_worker_num() const169 uint32_t PSContext::initial_worker_num() const { return worker_num_; }
170 
initial_server_num() const171 uint32_t PSContext::initial_server_num() const { return server_num_; }
172 
scheduler_host() const173 std::string PSContext::scheduler_host() const { return scheduler_host_; }
174 
SetPSRankId(uint32_t rank_id)175 void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
176 
ps_rank_id() const177 uint32_t PSContext::ps_rank_id() const { return rank_id_; }
178 
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size,int32_t param_key) const179 void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
180                                     size_t vocab_size, int32_t param_key) const {
181 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
182   if (enable_distributed_mindrt()) {
183     embedding_cache_table_manager.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size,
184                                                       param_key);
185   }
186 #endif
187 }
188 
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name) const189 void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name) const {
190 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
191   if (enable_distributed_mindrt()) {
192     embedding_cache_table_manager.ReInsertHashTableSize(new_param_name, cur_param_name);
193   }
194 #endif
195 }
196 
InsertAccumuInitInfo(const std::string & param_name,float init_val) const197 void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
198 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
199   embedding_cache_table_manager.InsertAccumuInitInfo(param_name, init_val);
200 #endif
201 }
202 
CloneHashTable(const std::string & dest_param_name,int32_t dest_param_key,const std::string & src_param_name,int32_t src_param_key) const203 void PSContext::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
204                                const std::string &src_param_name, int32_t src_param_key) const {
205 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
206   if (enable_distributed_mindrt()) {
207     embedding_cache_table_manager.CloneHashTable(dest_param_name, dest_param_key, src_param_name, src_param_key);
208   }
209 #endif
210 }
211 
set_cache_enable(bool cache_enable) const212 void PSContext::set_cache_enable(bool cache_enable) const {
213 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
214   PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
215 #endif
216 }
217 
cache_enable() const218 bool PSContext::cache_enable() const {
219 #if ((defined ENABLE_CPU) && (!defined _WIN32)) && !defined(__APPLE__)
220   return PsDataPrefetch::GetInstance().cache_enable();
221 #endif
222   return false;
223 }
224 
set_cache_size(size_t cache_size) const225 void PSContext::set_cache_size(size_t cache_size) const {
226 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
227   distributed::EmbeddingCacheTableManager::GetInstance().set_cache_size(cache_size);
228 #endif
229 }
230 
set_sparse_format(bool is_sparse)231 void PSContext::set_sparse_format(bool is_sparse) {
232 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
233   distributed::EmbeddingCacheTableManager::GetInstance().set_sparse_format(is_sparse);
234 #endif
235 }
236 
set_rank_id(uint32_t) const237 void PSContext::set_rank_id(uint32_t) const { return; }
238 
set_server_mode(const std::string & server_mode)239 void PSContext::set_server_mode(const std::string &server_mode) {
240   if (server_mode != kServerModePS) {
241     MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS;
242     return;
243   }
244   MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
245   server_mode_ = server_mode;
246 }
247 
server_mode() const248 const std::string &PSContext::server_mode() const { return server_mode_; }
249 
set_ms_role(const std::string & role)250 void PSContext::set_ms_role(const std::string &role) {
251   if (role != kEnvRoleOfWorker && role != kEnvRoleOfPServer && role != kEnvRoleOfServer &&
252       role != kEnvRoleOfScheduler) {
253     MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
254     return;
255   }
256   MS_LOG(INFO) << "MS_ROLE of this node is " << role;
257   role_ = role;
258 }
259 
set_worker_num(uint32_t worker_num)260 void PSContext::set_worker_num(uint32_t worker_num) { worker_num_ = worker_num; }
worker_num() const261 uint32_t PSContext::worker_num() const { return worker_num_; }
262 
set_server_num(uint32_t server_num)263 void PSContext::set_server_num(uint32_t server_num) { server_num_ = server_num; }
server_num() const264 uint32_t PSContext::server_num() const { return server_num_; }
265 
set_scheduler_ip(const std::string & sched_ip)266 void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
267 
scheduler_ip() const268 std::string PSContext::scheduler_ip() const { return scheduler_host_; }
269 
set_scheduler_port(uint16_t sched_port)270 void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
271 
scheduler_port() const272 uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
273 
cluster_config()274 core::ClusterConfig &PSContext::cluster_config() {
275   if (cluster_config_ == nullptr) {
276     cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
277     MS_EXCEPTION_IF_NULL(cluster_config_);
278   }
279   return *cluster_config_;
280 }
281 
set_scheduler_manage_port(uint16_t sched_port)282 void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
scheduler_manage_port() const283 uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
284 
set_config_file_path(const std::string & path)285 void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
286 
config_file_path() const287 std::string PSContext::config_file_path() const { return config_file_path_; }
288 
set_node_id(const std::string & node_id)289 void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
290 
node_id() const291 const std::string &PSContext::node_id() const { return node_id_; }
292 
enable_ssl() const293 bool PSContext::enable_ssl() const { return enable_ssl_; }
294 
set_enable_ssl(bool enabled)295 void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
296 
client_password()297 char *PSContext::client_password() { return client_password_; }
set_client_password(const char * password)298 void PSContext::set_client_password(const char *password) {
299   if (password == nullptr) {
300     MS_LOG(EXCEPTION) << "Can't set None or nullptr for client password.";
301   }
302   if (strlen(password) >= kMaxPasswordLen) {
303     MS_LOG(EXCEPTION) << "Client password is longer than max password length " << kMaxPasswordLen;
304   }
305   int ret = memcpy_s(client_password_, kMaxPasswordLen, password, strlen(password));
306   if (ret != EOK) {
307     MS_LOG(EXCEPTION) << "memcpy_s client password failed, error: " << ret;
308   }
309 }
310 
ClearClientPassword()311 void PSContext::ClearClientPassword() {
312   int ret = memset_s(client_password_, kMaxPasswordLen, 0x00, kMaxPasswordLen);
313   if (ret != 0) {
314     MS_LOG(EXCEPTION) << "Clear client password failed, error: " << ret;
315   }
316 }
317 
server_password()318 char *PSContext::server_password() { return server_password_; }
set_server_password(const char * password)319 void PSContext::set_server_password(const char *password) {
320   if (password == nullptr) {
321     MS_LOG(EXCEPTION) << "Can't set None or nullptr for server password.";
322   }
323   if (strlen(password) >= kMaxPasswordLen) {
324     MS_LOG(EXCEPTION) << "Client password is longer than max password length " << kMaxPasswordLen;
325   }
326   int ret = memcpy_s(server_password_, kMaxPasswordLen, password, strlen(password));
327   if (ret != EOK) {
328     MS_LOG(EXCEPTION) << "memcpy_s server password failed, error: " << ret;
329   }
330 }
331 
ClearServerPassword()332 void PSContext::ClearServerPassword() {
333   int ret = memset_s(server_password_, kMaxPasswordLen, 0x00, kMaxPasswordLen);
334   if (ret != 0) {
335     MS_LOG(EXCEPTION) << "Clear client password failed, error: " << ret;
336   }
337 }
338 
http_url_prefix() const339 std::string PSContext::http_url_prefix() const { return http_url_prefix_; }
340 
set_instance_name(const std::string & instance_name)341 void PSContext::set_instance_name(const std::string &instance_name) { instance_name_ = instance_name; }
342 
instance_name() const343 const std::string &PSContext::instance_name() const { return instance_name_; }
344 
enable_distributed_mindrt() const345 bool PSContext::enable_distributed_mindrt() const {
346   bool ms_cluster_enabled = distributed::cluster::ClusterContext::instance()->initialized();
347   return ms_cluster_enabled;
348 }
349 
set_checkpoint_load_status(bool status)350 void PSContext::set_checkpoint_load_status(bool status) {
351 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
352   return embedding_cache_table_manager.set_checkpoint_load_status(status);
353 #endif
354 }
355 
StoreWarmUpPtrByTensor(const int32_t param_key,const tensor::TensorPtr & tensor_ptr)356 int32_t PSContext::StoreWarmUpPtrByTensor(const int32_t param_key, const tensor::TensorPtr &tensor_ptr) {
357   MS_EXCEPTION_IF_NULL(tensor_ptr);
358 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
359   return embedding_cache_table_manager.StoreWarmUpPtr(param_key, tensor_ptr);
360 #else
361   return -1;
362 #endif
363 }
364 
StoreWarmUpPtrByTensorList(const int32_t param_key,const tensor::TensorPtr & key_ptr,const tensor::TensorPtr & value_ptr,const tensor::TensorPtr & status_ptr)365 int32_t PSContext::StoreWarmUpPtrByTensorList(const int32_t param_key, const tensor::TensorPtr &key_ptr,
366                                               const tensor::TensorPtr &value_ptr, const tensor::TensorPtr &status_ptr) {
367   MS_EXCEPTION_IF_NULL(key_ptr);
368   MS_EXCEPTION_IF_NULL(value_ptr);
369   MS_EXCEPTION_IF_NULL(status_ptr);
370 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
371   return embedding_cache_table_manager.StoreWarmUpPtr(param_key, key_ptr, value_ptr, status_ptr);
372 #else
373   return -1;
374 #endif
375 }
376 }  // namespace ps
377 }  // namespace mindspore
378