• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
18 #define MINDSPORE_CCSRC_PS_CONTEXT_H_
19 
20 #include <map>
21 #include <string>
22 #include <memory>
23 #include "include/backend/distributed/ps/constants.h"
24 #include "include/backend/visible.h"
25 #include "ir/tensor.h"
26 
27 namespace mindspore {
28 namespace ps {
// String constants naming the PS server mode and the process roles.
// NOTE(review): kEnvRole looks like the name of the environment variable that carries the role, and the
// kEnvRoleOf* constants look like the values it may hold — confirm against the code that reads them.
29 constexpr char kServerModePS[] = "PARAMETER_SERVER";
30 constexpr char kEnvRole[] = "MS_ROLE";
31 constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
32 constexpr char kEnvRoleOfServer[] = "MS_SERVER";
33 constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
34 constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
35 constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
// Number of chars in each of the fixed-size client/server password buffers held by PSContext.
36 constexpr size_t kMaxPasswordLen = 1024;
37 
// Forward declaration only: this header uses core::ClusterConfig solely through a std::unique_ptr member and
// a reference-returning accessor of PSContext, so the complete type is not required here.
38 namespace core {
39 struct ClusterConfig;
40 }  // namespace core
41 
// Holds the parameter-server (PS) configuration of the current process: role flags, rank id, worker/server
// counts, scheduler address and ports, SSL settings and the passwords used to decode p12 files.
// Instances are handed out by instance() as a std::shared_ptr; the constructor is private and copying is
// deleted. Every member function is only declared here; the definitions live outside this header.
42 class BACKEND_EXPORT PSContext {
43  public:
44   ~PSContext();
  // Non-copyable.
45   PSContext(PSContext const &) = delete;
46   PSContext &operator=(const PSContext &) = delete;
  // Access point for the context.
  // NOTE(review): presumably always returns the same shared instance — confirm in the implementation file.
47   static std::shared_ptr<PSContext> instance();
48 
  // PS mode switch plus role, topology and rank queries.
  // NOTE(review): these presumably read/write ps_enabled_, is_worker_, is_pserver_, is_sched_, worker_num_,
  // server_num_, scheduler_host_ and rank_id_ — confirm in the implementation file.
49   void SetPSEnable(bool enabled);
50   bool is_ps_mode() const;
51   void Reset();
52   std::string ms_role() const;
53   bool is_worker() const;
54   bool is_server() const;
55   bool is_scheduler() const;
56   uint32_t initial_worker_num() const;
57   uint32_t initial_server_num() const;
58   std::string scheduler_host() const;
59   void SetPSRankId(uint32_t rank_id);
60   uint32_t ps_rank_id() const;
  // Hash-table and cache bookkeeping, addressed by parameter name and/or parameter key.
  // These are declared const, so they cannot modify this object's data members (none is declared mutable).
  // NOTE(review): presumably they forward to a cache component defined elsewhere — confirm.
61   void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
62                            size_t vocab_size, int32_t param_key) const;
63   void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name) const;
64   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
65   void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name,
66                       int32_t src_param_key) const;
67   void set_cache_enable(bool cache_enable) const;
68   bool cache_enable() const;
69 
70   // Set embedding cache size for ps cache mode.
71   void set_cache_size(size_t cache_size) const;
72 
73   // Set if the storage format of embedding table is sparse or not.
74   void set_sparse_format(bool is_sparse);
75 
  // Declared const, unlike SetPSRankId() above.
  // NOTE(review): likely sets a rank id stored outside this object rather than rank_id_ — confirm.
76   void set_rank_id(uint32_t rank_id) const;
77 
78   // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set
79   // by ps_context.
80   void set_server_mode(const std::string &server_mode);
81   const std::string &server_mode() const;
82 
83   void set_ms_role(const std::string &role);
84 
85   void set_worker_num(uint32_t worker_num);
86   uint32_t worker_num() const;
87 
88   void set_server_num(uint32_t server_num);
89   uint32_t server_num() const;
90 
  // NOTE(review): only scheduler_host_ is declared as a data member; scheduler_ip()/set_scheduler_ip()
  // presumably share it with scheduler_host() — confirm.
91   void set_scheduler_ip(const std::string &sched_ip);
92   std::string scheduler_ip() const;
93 
94   void set_scheduler_port(uint16_t sched_port);
95   uint16_t scheduler_port() const;
96 
  // Returns a non-const reference, so callers can modify the cluster config through it.
  // NOTE(review): presumably dereferences cluster_config_; whether it can be null is not visible here.
97   core::ClusterConfig &cluster_config();
98 
99   void set_scheduler_manage_port(uint16_t sched_port);
100   uint16_t scheduler_manage_port() const;
101 
102   void set_config_file_path(const std::string &path);
103   std::string config_file_path() const;
104 
105   void set_node_id(const std::string &node_id);
106   const std::string &node_id() const;
107 
108   bool enable_ssl() const;
109   void set_enable_ssl(bool enabled);
110 
  // Client password access. The getter returns a non-const char pointer.
  // NOTE(review): presumably it points at client_password_, a fixed buffer of kMaxPasswordLen chars; how
  // set_client_password() handles input longer than the buffer is not visible here — confirm.
111   char *client_password();
112   void set_client_password(const char *password);
113   void ClearClientPassword();
114 
  // Server password access; same shape as the client password functions above.
  // NOTE(review): presumably backed by server_password_ (kMaxPasswordLen chars) — confirm.
115   char *server_password();
116   void set_server_password(const char *password);
117   void ClearServerPassword();
118 
119   std::string http_url_prefix() const;
120 
121   void set_instance_name(const std::string &instance_name);
122   const std::string &instance_name() const;
123 
124   // Whether distributed MindRT is enabled.
125   bool enable_distributed_mindrt() const;
126 
  // NOTE(review): no data member in this class obviously stores this status; where it goes is not visible here.
127   void set_checkpoint_load_status(bool status);
128 
  // Warm-up storage entry points taking a parameter key and one or three tensors (key/value/status).
  // NOTE(review): the meaning of the int32_t return value is not visible in this header — confirm.
129   int32_t StoreWarmUpPtrByTensor(int32_t param_key, const tensor::TensorPtr &tensor_ptr);
130 
131   int32_t StoreWarmUpPtrByTensorList(int32_t param_key, const tensor::TensorPtr &key_ptr,
132                                      const tensor::TensorPtr &value_ptr, const tensor::TensorPtr &status_ptr);
133 
134  private:
  // Private so that instances cannot be constructed from outside the class.
135   PSContext();
136 
  // Role flags, rank id and topology. No data member of this class has an in-class initializer, so any
  // initial values must come from the constructor, whose definition is not in this header.
137   bool ps_enabled_;
138   bool is_worker_;
139   bool is_pserver_;
140   bool is_sched_;
141   uint32_t rank_id_;
142   uint32_t worker_num_;
143   uint32_t server_num_;
144   std::string scheduler_host_;
145   uint16_t scheduler_port_;
146 
147   // The server process's role.
148   std::string role_;
149 
150   // Server mode which could be Parameter Server.
151   std::string server_mode_;
152 
153   // The cluster config read through environment variables, the value does not change.
154   std::unique_ptr<core::ClusterConfig> cluster_config_;
155 
156   // The port used by scheduler to receive http requests for scale out or scale in.
157   uint16_t scheduler_manage_port_;
158 
159   // The path of the configuration file, used to configure the certification path and persistent storage type, etc.
160   std::string config_file_path_;
161 
162   // Unique id of the node
163   std::string node_id_;
164 
165   // Whether to enable ssl for network communication.
166   bool enable_ssl_;
167   // Password used to decode p12 file. Fixed-size buffer of kMaxPasswordLen chars.
168   char client_password_[kMaxPasswordLen];
169   // Password used to decode p12 file. Fixed-size buffer of kMaxPasswordLen chars.
170   char server_password_[kMaxPasswordLen];
171   // http url prefix for http communication
172   std::string http_url_prefix_;
173   // The name of instance
174   std::string instance_name_;
175 };
176 }  // namespace ps
177 }  // namespace mindspore
178 #endif  // MINDSPORE_CCSRC_PS_CONTEXT_H_
179