1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_ 18 #define MINDSPORE_CCSRC_PS_CONTEXT_H_ 19 20 #include <map> 21 #include <string> 22 #include <memory> 23 #include "include/backend/distributed/ps/constants.h" 24 #include "include/backend/visible.h" 25 #include "ir/tensor.h" 26 27 namespace mindspore { 28 namespace ps { 29 constexpr char kServerModePS[] = "PARAMETER_SERVER"; 30 constexpr char kEnvRole[] = "MS_ROLE"; 31 constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; 32 constexpr char kEnvRoleOfServer[] = "MS_SERVER"; 33 constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; 34 constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; 35 constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; 36 constexpr size_t kMaxPasswordLen = 1024; 37 38 namespace core { 39 struct ClusterConfig; 40 } // namespace core 41 42 class BACKEND_EXPORT PSContext { 43 public: 44 ~PSContext(); 45 PSContext(PSContext const &) = delete; 46 PSContext &operator=(const PSContext &) = delete; 47 static std::shared_ptr<PSContext> instance(); 48 49 void SetPSEnable(bool enabled); 50 bool is_ps_mode() const; 51 void Reset(); 52 std::string ms_role() const; 53 bool is_worker() const; 54 bool is_server() const; 55 bool is_scheduler() const; 56 uint32_t initial_worker_num() const; 57 uint32_t initial_server_num() const; 58 std::string scheduler_host() const; 59 void SetPSRankId(uint32_t rank_id); 60 uint32_t ps_rank_id() const; 61 void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, 62 size_t vocab_size, int32_t param_key) const; 63 void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name) const; 64 void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; 65 void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name, 66 int32_t src_param_key) const; 67 void set_cache_enable(bool cache_enable) const; 68 bool cache_enable() const; 69 70 // Set embedding cache size for ps cache mode. 71 void set_cache_size(size_t cache_size) const; 72 73 // Set if the storage format of embedding table is sparse or not. 74 void set_sparse_format(bool is_sparse); 75 76 void set_rank_id(uint32_t rank_id) const; 77 78 // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set 79 // by ps_context. 80 void set_server_mode(const std::string &server_mode); 81 const std::string &server_mode() const; 82 83 void set_ms_role(const std::string &role); 84 85 void set_worker_num(uint32_t worker_num); 86 uint32_t worker_num() const; 87 88 void set_server_num(uint32_t server_num); 89 uint32_t server_num() const; 90 91 void set_scheduler_ip(const std::string &sched_ip); 92 std::string scheduler_ip() const; 93 94 void set_scheduler_port(uint16_t sched_port); 95 uint16_t scheduler_port() const; 96 97 core::ClusterConfig &cluster_config(); 98 99 void set_scheduler_manage_port(uint16_t sched_port); 100 uint16_t scheduler_manage_port() const; 101 102 void set_config_file_path(const std::string &path); 103 std::string config_file_path() const; 104 105 void set_node_id(const std::string &node_id); 106 const std::string &node_id() const; 107 108 bool enable_ssl() const; 109 void set_enable_ssl(bool enabled); 110 111 char *client_password(); 112 void set_client_password(const char *password); 113 void ClearClientPassword(); 114 115 char *server_password(); 116 void set_server_password(const char *password); 117 void ClearServerPassword(); 118 119 std::string http_url_prefix() const; 120 121 void set_instance_name(const std::string &instance_name); 122 const std::string &instance_name() const; 123 124 // Whether distributed MindRT is enabled. 125 bool enable_distributed_mindrt() const; 126 127 void set_checkpoint_load_status(bool status); 128 129 int32_t StoreWarmUpPtrByTensor(int32_t param_key, const tensor::TensorPtr &tensor_ptr); 130 131 int32_t StoreWarmUpPtrByTensorList(int32_t param_key, const tensor::TensorPtr &key_ptr, 132 const tensor::TensorPtr &value_ptr, const tensor::TensorPtr &status_ptr); 133 134 private: 135 PSContext(); 136 137 bool ps_enabled_; 138 bool is_worker_; 139 bool is_pserver_; 140 bool is_sched_; 141 uint32_t rank_id_; 142 uint32_t worker_num_; 143 uint32_t server_num_; 144 std::string scheduler_host_; 145 uint16_t scheduler_port_; 146 147 // The server process's role. 148 std::string role_; 149 150 // Server mode which could be Parameter Server. 151 std::string server_mode_; 152 153 // The cluster config read through environment variables, the value does not change. 154 std::unique_ptr<core::ClusterConfig> cluster_config_; 155 156 // The port used by scheduler to receive http requests for scale out or scale in. 157 uint16_t scheduler_manage_port_; 158 159 // The path of the configuration file, used to configure the certification path and persistent storage type, etc. 160 std::string config_file_path_; 161 162 // Unique id of the node 163 std::string node_id_; 164 165 // Whether to enable ssl for network communication. 166 bool enable_ssl_; 167 // Password used to decode p12 file. 168 char client_password_[kMaxPasswordLen]; 169 // Password used to decode p12 file. 170 char server_password_[kMaxPasswordLen]; 171 // http url prefix for http communication 172 std::string http_url_prefix_; 173 // The name of instance 174 std::string instance_name_; 175 }; 176 } // namespace ps 177 } // namespace mindspore 178 #endif // MINDSPORE_CCSRC_PS_CONTEXT_H_ 179