1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ 18 #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ 19 20 #include <unistd.h> 21 #include <unordered_map> 22 #include <string> 23 #include <iostream> 24 #include <memory> 25 #include <vector> 26 #include <mutex> 27 #include <condition_variable> 28 #include <thread> 29 #include <cmath> 30 #include <random> 31 #include <utility> 32 #include <list> 33 #include <map> 34 #include <functional> 35 #include <algorithm> 36 37 #include "ir/func_graph.h" 38 #include "backend/session/session_basic.h" 39 #include "backend/session/anf_runtime_algorithm.h" 40 #include "backend/session/session_factory.h" 41 #include "ps/optimizer_info.h" 42 #include "ps/optimizer_info_builder.h" 43 #include "ps/ps_context.h" 44 #include "runtime/device/cpu/kernel_select_cpu.h" 45 #include "utils/ms_context.h" 46 #include "backend/kernel_compiler/kernel.h" 47 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" 48 #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" 49 #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" 50 #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" 51 #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" 52 #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" 53 #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" 54 #include "ps/ps_cache/ps_data/ps_data_prefetch.h" 55 #include "ps/random_normal/random_normal.h" 56 57 #include "ps/constants.h" 58 #include "ps/util.h" 59 #include "ps/embedding_table_shard_metadata.h" 60 #include "utils/log_adapter.h" 61 #include "proto/comm.pb.h" 62 #include "proto/ps.pb.h" 63 #include "ps/core/server_node.h" 64 #include "ps/core/node.h" 65 66 namespace mindspore { 67 namespace ps { 68 class ParameterServer { 69 public: GetInstance()70 static ParameterServer &GetInstance() { 71 static ParameterServer instance; 72 return instance; 73 } 74 75 void Run(const FuncGraphPtr &func_graph); 76 77 private: ParameterServer()78 ParameterServer() 79 : pserver_num_(0), 80 worker_num_(0), 81 grad_accum_count_(0), 82 handler_(nullptr), 83 func_graph_(nullptr), 84 sess_(nullptr), 85 running_(true), 86 thread_(nullptr), 87 server_node_(nullptr) {} 88 ~ParameterServer() = default; 89 ParameterServer(const ParameterServer &) = delete; 90 ParameterServer &operator=(const ParameterServer &) = delete; 91 92 class ServerHandler { 93 public: ServerHandler(ParameterServer * ps)94 explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} 95 ~ServerHandler() = default; 96 void Init(); 97 void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta, 98 const DataPtr &data, size_t size); 99 void HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res); 100 void HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res); 101 void HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res); 102 void HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res); 103 void HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res); 104 void HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res); 105 void HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res); 106 void HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res); 107 void HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res); 108 void HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res); 109 void HandleFinalize(const DataPtr &data, size_t size, const VectorPtr &res); 110 111 private: 112 ParameterServer *ps_; 113 typedef void (ServerHandler::*RequestHandler)(const DataPtr &data, size_t size, const VectorPtr &res); 114 std::unordered_map<int, RequestHandler> handlers_; 115 std::unordered_map<int, std::string> commands_; 116 std::unordered_map<Key, bool> init_weights_; 117 std::unordered_map<Key, bool> init_weight_to_optim_; 118 std::unordered_map<Key, bool> init_optim_info_; 119 }; 120 121 bool Init(const FuncGraphPtr &func_graph); 122 void InitOptimInfoBuilders(); 123 void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id); 124 void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); 125 void InitWeight(const Key &key, const WeightPtr &weight); 126 void InitGrad(const Key &key, const GradPtr &grad); 127 void InitEmbeddingTable(const Key &key, 128 const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, 129 const ParamInitInfo ¶m_init_info); 130 bool HasWeight(const Key &key); 131 void Finalize(); 132 void UpdateWeights(); 133 void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); 134 WeightPtr weight(const Key &key); 135 void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res); 136 void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); 137 inline bool ReadyForUpdateWeights() const; 138 inline bool ReadyForPush(const Key &key); 139 inline bool ReadyForPull(const Key &key); 140 inline void ResetGradAccumCount(); 141 const CNodePtr GetCNode(const std::string &name) const; 142 inline std::mutex &mutex(); 143 void GetEmbeddingTableParamPtr(); 144 void SyncEmbeddingTables(); 145 // Cache embedding table parameter by map, key: parameter name, value: parameter node pointer 146 void CacheEmbeddingTableParamPtr(); 147 148 size_t pserver_num_; 149 size_t worker_num_; 150 size_t grad_accum_count_; 151 std::unique_ptr<ServerHandler> handler_; 152 FuncGraphPtr func_graph_; 153 std::shared_ptr<session::SessionBasic> sess_; 154 bool running_; 155 bool embedding_param_ptr_cached_{false}; 156 // Used to cache embedding table parameter, key: parameter name, value: parameter node pointer 157 std::map<std::string, ParameterPtr> embedding_parameter_tables_; 158 159 std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_; 160 std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_; 161 std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_; 162 std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_; 163 std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_; 164 std::unordered_map<Key, std::string> weight_key_to_optims_; 165 std::unordered_map<Key, std::string> weight_key_to_optim_op_; 166 std::unordered_map<Key, WeightPtr> weights_; 167 std::unordered_map<Key, bool> is_embedding_; 168 std::unordered_map<Key, WeightPtr> grads_; 169 std::unordered_map<Key, size_t> grads_accum_counter_; 170 std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_; 171 std::unordered_map<Key, uint64_t> tokens_; 172 173 std::mutex mutex_; 174 std::condition_variable apply_grads_cv_; 175 176 std::unique_ptr<std::thread> thread_; 177 std::shared_ptr<core::ServerNode> server_node_; 178 std::map<Key, ParameterPtr> embedding_tables_; 179 180 friend class ServerHandler; 181 }; 182 } // namespace ps 183 } // namespace mindspore 184 #endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ 185