• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
18 #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
19 
20 #include <unistd.h>
21 #include <unordered_map>
22 #include <string>
23 #include <iostream>
24 #include <memory>
25 #include <vector>
26 #include <mutex>
27 #include <condition_variable>
28 #include <thread>
29 #include <cmath>
30 #include <random>
31 #include <utility>
32 #include <list>
33 #include <map>
34 #include <functional>
35 #include <algorithm>
36 
37 #include "ir/func_graph.h"
38 #include "backend/session/session_basic.h"
39 #include "backend/session/anf_runtime_algorithm.h"
40 #include "backend/session/session_factory.h"
41 #include "ps/optimizer_info.h"
42 #include "ps/optimizer_info_builder.h"
43 #include "ps/ps_context.h"
44 #include "runtime/device/cpu/kernel_select_cpu.h"
45 #include "utils/ms_context.h"
46 #include "backend/kernel_compiler/kernel.h"
47 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
48 #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
49 #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
50 #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
51 #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
52 #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
53 #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
54 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
55 #include "ps/random_normal/random_normal.h"
56 
57 #include "ps/constants.h"
58 #include "ps/util.h"
59 #include "ps/embedding_table_shard_metadata.h"
60 #include "utils/log_adapter.h"
61 #include "proto/comm.pb.h"
62 #include "proto/ps.pb.h"
63 #include "ps/core/server_node.h"
64 #include "ps/core/node.h"
65 
66 namespace mindspore {
67 namespace ps {
68 class ParameterServer {
69  public:
GetInstance()70   static ParameterServer &GetInstance() {
71     static ParameterServer instance;
72     return instance;
73   }
74 
75   void Run(const FuncGraphPtr &func_graph);
76 
77  private:
ParameterServer()78   ParameterServer()
79       : pserver_num_(0),
80         worker_num_(0),
81         grad_accum_count_(0),
82         handler_(nullptr),
83         func_graph_(nullptr),
84         sess_(nullptr),
85         running_(true),
86         thread_(nullptr),
87         server_node_(nullptr) {}
88   ~ParameterServer() = default;
89   ParameterServer(const ParameterServer &) = delete;
90   ParameterServer &operator=(const ParameterServer &) = delete;
91 
92   class ServerHandler {
93    public:
ServerHandler(ParameterServer * ps)94     explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
95     ~ServerHandler() = default;
96     void Init();
97     void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta,
98                     const DataPtr &data, size_t size);
99     void HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res);
100     void HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res);
101     void HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res);
102     void HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res);
103     void HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res);
104     void HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
105     void HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res);
106     void HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res);
107     void HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res);
108     void HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
109     void HandleFinalize(const DataPtr &data, size_t size, const VectorPtr &res);
110 
111    private:
112     ParameterServer *ps_;
113     typedef void (ServerHandler::*RequestHandler)(const DataPtr &data, size_t size, const VectorPtr &res);
114     std::unordered_map<int, RequestHandler> handlers_;
115     std::unordered_map<int, std::string> commands_;
116     std::unordered_map<Key, bool> init_weights_;
117     std::unordered_map<Key, bool> init_weight_to_optim_;
118     std::unordered_map<Key, bool> init_optim_info_;
119   };
120 
121   bool Init(const FuncGraphPtr &func_graph);
122   void InitOptimInfoBuilders();
123   void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
124   void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
125   void InitWeight(const Key &key, const WeightPtr &weight);
126   void InitGrad(const Key &key, const GradPtr &grad);
127   void InitEmbeddingTable(const Key &key,
128                           const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
129                           const ParamInitInfo &param_init_info);
130   bool HasWeight(const Key &key);
131   void Finalize();
132   void UpdateWeights();
133   void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
134   WeightPtr weight(const Key &key);
135   void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
136   void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
137   inline bool ReadyForUpdateWeights() const;
138   inline bool ReadyForPush(const Key &key);
139   inline bool ReadyForPull(const Key &key);
140   inline void ResetGradAccumCount();
141   const CNodePtr GetCNode(const std::string &name) const;
142   inline std::mutex &mutex();
143   void GetEmbeddingTableParamPtr();
144   void SyncEmbeddingTables();
145   // Cache embedding table parameter by map, key: parameter name, value: parameter node pointer
146   void CacheEmbeddingTableParamPtr();
147 
148   size_t pserver_num_;
149   size_t worker_num_;
150   size_t grad_accum_count_;
151   std::unique_ptr<ServerHandler> handler_;
152   FuncGraphPtr func_graph_;
153   std::shared_ptr<session::SessionBasic> sess_;
154   bool running_;
155   bool embedding_param_ptr_cached_{false};
156   // Used to cache embedding table parameter, key: parameter name, value: parameter node pointer
157   std::map<std::string, ParameterPtr> embedding_parameter_tables_;
158 
159   std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
160   std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
161   std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
162   std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
163   std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
164   std::unordered_map<Key, std::string> weight_key_to_optims_;
165   std::unordered_map<Key, std::string> weight_key_to_optim_op_;
166   std::unordered_map<Key, WeightPtr> weights_;
167   std::unordered_map<Key, bool> is_embedding_;
168   std::unordered_map<Key, WeightPtr> grads_;
169   std::unordered_map<Key, size_t> grads_accum_counter_;
170   std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
171   std::unordered_map<Key, uint64_t> tokens_;
172 
173   std::mutex mutex_;
174   std::condition_variable apply_grads_cv_;
175 
176   std::unique_ptr<std::thread> thread_;
177   std::shared_ptr<core::ServerNode> server_node_;
178   std::map<Key, ParameterPtr> embedding_tables_;
179 
180   friend class ServerHandler;
181 };
182 }  // namespace ps
183 }  // namespace mindspore
184 #endif  // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
185