/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/parameter_server.h"
#include <algorithm>
#include <thread>

namespace mindspore {
namespace ps {
static const uint32_t kMaxThreadNum = 16;
static const uint32_t kCPUCoreNum = std::thread::hardware_concurrency();

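// Main entry of the parameter server: creates the server node, initializes server
// state from the function graph, starts the node, and then blocks on the
// weight-update thread until Finalize() is triggered. On exit it syncs the
// embedding tables back to the graph parameters and stops the node.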
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
  server_node_ = std::make_shared<core::ServerNode>();

  MS_LOG(INFO) << "PServer connected successfully.";
  if (!PSContext::instance()->is_server()) {
    MS_LOG(INFO) << "This is not the Server node.";
    return;
  }
  Init(func_graph);
  server_node_->Start();
  PSContext::instance()->SetPSRankId(server_node_->rank_id());
  thread_->join();
  SyncEmbeddingTables();
  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
  server_node_->Finish();
  if (!server_node_->Stop()) {
    MS_LOG(WARNING) << "Parameter server stop failed.";
  }
  MS_LOG(INFO) << "PServer finalized successfully.";
}

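// Reads the server and worker counts from the environment, registers the message
// handler and cluster-timeout callbacks, collects the embedding table parameters
// from the graph, and launches the background weight-update thread.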
bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
  pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
  worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
  func_graph_ = func_graph;
  handler_.reset(new ServerHandler(this));
  handler_->Init();

  InitOptimInfoBuilders();
  server_node_->set_handler(*handler_);
  server_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
    this->Finalize();
  });
  server_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: NODE_TIMEOUT begin to exit the system!";
    this->Finalize();
  });
  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
  GetEmbeddingTableParamPtr();
  return true;
}

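// Registers an OptimizerInfoBuilder for each supported optimizer (momentum,
// sparse Adam, sparse FTRL) so AccumGrad can build the matching OptimizerInfo
// the first time gradients arrive for a key.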
void ParameterServer::InitOptimInfoBuilders() {
  std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder =
    std::make_shared<SparseAdamOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder =
    std::make_shared<SparseFtrlOptimInfoBuilder>(worker_num_);
  optim_info_builders_[kApplyMomentum] = momentum_info_builder;
  optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
  optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
}

void ParameterServer::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) {
  if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
    return;
  }
  weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
  MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
               << ", optimizer op name:" << weight_key_to_optim_op_[key];
}

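// Records the optimizer input shapes reported for keys[0] and, on first use,
// creates and initializes the corresponding PS optimizer kernel for that key.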
void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
  InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(inputs_shape);
  InputsShapePtr original_inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(original_inputs_shape);
  size_t val_idx = 0;
  const Key &key = keys[0];
  MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
  if (optim_inputs_shape_.count(key) == 0) {
    original_optim_inputs_shape_[key] = original_inputs_shape;
    optim_inputs_shape_[key] = inputs_shape;
  }
  for (size_t i = 0; i < keys.size(); i++) {
    auto shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(shape);
    auto original_shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(original_shape);
    inputs_shape->push_back(shape);
    original_inputs_shape->push_back(original_shape);

    for (int64_t j = 0; j < lengths[i]; j++) {
      shape->push_back(values[val_idx]);
      original_shape->push_back(values[val_idx++]);
    }
  }
  if (weight_key_to_optims_.count(key) > 0) {
    const std::string &optim_name = weight_key_to_optims_[key];
    const std::string &optim_op_name = weight_key_to_optim_op_[key];
    if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
      const CNodePtr cnode = GetCNode(optim_op_name);
      MS_EXCEPTION_IF_NULL(cnode);
      if (optim_name == kSparseAdam) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseLazyAdam) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kApplyMomentum) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseFtrl) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      }
    }
  }
}

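// Stores the initial weight for a key. An existing weight is only overwritten if
// the key currently refers to an embedding table; afterwards the key is marked as
// a plain (non-embedding) weight and its pull token is reset.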
void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
  MS_EXCEPTION_IF_NULL(weight);
  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << server_node_->rank_id();
    weights_[key] = weight;
    tokens_[key] = 0;
    is_embedding_[key] = false;
  }
}

void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) {
  MS_EXCEPTION_IF_NULL(grad);
  if (grads_.count(key) == 0) {
    grads_[key] = grad;
    grads_accum_counter_[key] = 0;
  }
}

namespace {
// Initialize the accumulation values in parallel with multiple threads.
void InitAccumParallel(float init_value, size_t total_len, float *embedding_data) {
  MS_EXCEPTION_IF_NULL(embedding_data);
  auto init_task = [](float value, size_t task_len, float *data) {
    for (size_t i = 0; i < task_len; i++) {
      data[i] = value;
    }
  };

  size_t thread_num = std::max(kMaxThreadNum, kCPUCoreNum);
  if (total_len <= thread_num) {
    thread_num = 1;
  }

  std::vector<std::thread> threads(thread_num);
  size_t task_offset = 0;

  for (size_t i = 0; i < thread_num; ++i) {
    // The value of thread_num is >= 1.
    size_t task_len = total_len / thread_num + (i < (total_len % thread_num) ? 1 : 0);
    threads[i] = std::thread(init_task, init_value, task_len, embedding_data + task_offset);
    task_offset += task_len;
  }

  for (size_t i = 0; i < thread_num; i++) {
    threads[i].join();
  }
}

void CopyTensorData(void *dest_ptr, size_t tensor_size, const void *src_ptr) {
  MS_EXCEPTION_IF_NULL(dest_ptr);
  MS_EXCEPTION_IF_NULL(src_ptr);
  char *dest = reinterpret_cast<char *>(dest_ptr);
  const char *src = reinterpret_cast<const char *>(src_ptr);

  // The secure memcpy function 'memcpy_s' requires its second parameter 'destMax' to be no greater than
  // SECUREC_MEM_MAX_LEN. If the tensor size (buffer length) exceeds SECUREC_MEM_MAX_LEN, the tensor has to be
  // copied in segments.
  for (size_t offset = 0; offset < tensor_size; offset += SECUREC_MEM_MAX_LEN) {
    size_t copy_len = std::min(tensor_size - offset, SECUREC_MEM_MAX_LEN);
    size_t dest_len = copy_len;
    int ret = memcpy_s(dest + offset, dest_len, src + offset, copy_len);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "Failed to memcpy tensor, errno(" << ret << ")";
    }
  }
}
}  // namespace

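// Creates the embedding lookup kernel for the key and allocates the embedding
// table. With the embedding cache enabled, weights are initialized from the
// recorded seeds (or with a constant for accumulation parameters); otherwise the
// table is filled with normally distributed random values.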
void ParameterServer::InitEmbeddingTable(
  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
  const ParamInitInfo &param_init_info) {
  MS_EXCEPTION_IF_NULL(shapes);
  if (weights_.count(key) == 0) {
    std::shared_ptr<PServerKernel> lookup =
      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
    lookup->InitKernel(shapes);
    embedding_lookup_ops_[key] = lookup;

    // Init embedding weight
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    size_t total_dims =
      std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
    MS_EXCEPTION_IF_NULL(embedding);
    float *embedding_data = embedding->data();
    std::default_random_engine engine;
    std::normal_distribution<float> random(0, kStdDev);
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      CacheEmbeddingTableParamPtr();
      if (param_init_info.param_type_ == kWeight) {
        const std::string &param_name = param_init_info.param_name_;
        auto iter = embedding_parameter_tables_.find(param_name);
        if (iter == embedding_parameter_tables_.end()) {
          MS_LOG(EXCEPTION) << "Can not find parameter info for: " << param_name;
        }
        // Cache embedding table parameter by weight key to parameter node pointer.
        (void)embedding_tables_.emplace(key, iter->second);

        InitRandomNormal(0, kStdDev, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_,
                         embedding_data);
      } else if (param_init_info.param_type_ == kAccumulation) {
        InitAccumParallel(param_init_info.init_val_, total_dims, embedding_data);
      }
    } else {
      for (size_t i = 0; i < total_dims; i++) {
        embedding_data[i] = random(engine);
      }
    }
    weights_[key] = embedding;
    MS_LOG(DEBUG) << "The key:" << key << " the embedding:" << *embedding;
    tokens_[key] = 0;
    is_embedding_[key] = true;

    grads_accum_counter_[key] = 0;
  }
}

bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_.count(key)); }

void ParameterServer::Finalize() {
  running_ = false;
  apply_grads_cv_.notify_one();
}

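// Body of the background update thread: waits until every weight has gradients
// from all workers, runs the optimizer kernel for each key, then resets the
// accumulation counters. The loop exits once running_ is cleared by Finalize().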
void ParameterServer::UpdateWeights() {
  while (true) {
    MS_LOG(INFO) << "The running is:" << running_ << " the ready is:" << this->ReadyForUpdateWeights();
    std::unique_lock<std::mutex> lock(mutex_);
    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
    if (!running_) {
      break;
    }

    for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
      Key key = iter->first;
      WeightPtr weight_ptr = iter->second;

      std::shared_ptr<PServerKernel> optimizer = nullptr;
      if (weight_key_to_optims_.count(key) > 0) {
        optimizer = optimizers_[key];
      }
      MS_EXCEPTION_IF_NULL(optimizer);

      std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
      if (optim_info != nullptr) {
        const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
        const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
        const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();

        std::vector<std::vector<size_t>> shapes = {};
        std::vector<size_t> indices_shape = {};
        indices_shape.emplace_back(optim_info->indice_size());
        shapes.push_back(indices_shape);

        if (original_optim_inputs_shape_.count(key) != 0) {
          std::transform((*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
                         std::back_inserter(shapes),
                         [](const std::shared_ptr<std::vector<size_t>> &input_shapes) -> std::vector<size_t> {
                           return *input_shapes;
                         });
        }
        optimizer->ReInit(shapes);
        optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
        optimizer->Execute(inputs, workspaces, outputs);
        optim_info->Reset();
      }
      if (!is_embedding_[key]) {
        tokens_[key] = worker_num_;
      }
    }
    ResetGradAccumCount();
  }
}

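// Accumulates one worker's gradients for keys[0] into its OptimizerInfo and wakes
// the update thread once gradients from every worker have arrived.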
void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
  std::unique_lock<std::mutex> lock(mutex_);
  const Key &key = keys[0];
  bool no_sparse_grad = values.size() == 1 && values[0] == kGradValue;
  if (!no_sparse_grad) {
    std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];

    // Create or update the optimizer info
    if (optim_info == nullptr) {
      const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
      std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
      if (pserver_kernel == nullptr) {
        MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
      }
      MS_EXCEPTION_IF_NULL(pserver_kernel);
      OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
                                            optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
      optim_info.reset(optim);
      optim_infos_[key] = optim_info;
    } else {
      optim_info->Update(values, lengths);
      optim_info->Accumulate(values, lengths);
    }
  }

  grads_accum_counter_[key] += 1;
  if (grads_accum_counter_[key] == worker_num_) {
    grad_accum_count_++;
  }
  if (ReadyForUpdateWeights()) {
    apply_grads_cv_.notify_one();
  }
}

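// Returns a copy of the weight for the given key and consumes one pull token.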
WeightPtr ParameterServer::weight(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.count(key) == 0) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  WeightPtr weight_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(weight_ptr);
  WeightPtr copy_weight_ptr = std::make_shared<std::vector<float>>(*weight_ptr);
  MS_EXCEPTION_IF_NULL(copy_weight_ptr);
  tokens_[key] -= 1;
  return copy_weight_ptr;
}

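// Gathers the embedding rows for the given lookup ids by running the embedding
// lookup PS kernel and writes the gathered values into the response message.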
void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res) {
  std::unique_lock<std::mutex> lock(mutex_);
  MS_EXCEPTION_IF_NULL(res);
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);

  // Update shapes of lookup operator
  std::vector<std::vector<size_t>> shapes = {};
  std::vector<size_t> indices_shape = {};
  indices_shape.emplace_back(lookup_ids.size());
  shapes.push_back(indices_shape);
  table_lookup_op->ReInit(shapes);

  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
  std::vector<kernel::AddressPtr> inputs;
  AddressPtr embedding_table = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(embedding_table);
  AddressPtr indices = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(indices);
  inputs.push_back(embedding_table);
  inputs.push_back(indices);
  embedding_table->addr = table_ptr->data();
  embedding_table->size = table_ptr->size() * sizeof(float);

  std::unique_ptr<int[]> tmp_ids = std::make_unique<int[]>(lookup_ids.size());
  MS_EXCEPTION_IF_NULL(tmp_ids);
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    tmp_ids[i] = static_cast<int>(lookup_ids[i]);
  }
  indices->addr = tmp_ids.get();
  indices->size = lookup_ids.size() * sizeof(int);

  std::vector<kernel::AddressPtr> workspaces;
  std::vector<kernel::AddressPtr> outputs;
  AddressPtr output = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(output);
  std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(float), 0);
  MS_EXCEPTION_IF_NULL(addr);

  output->addr = addr->data();
  output->size = output_shapes[0];
  outputs.push_back(output);

  table_lookup_op->Execute(inputs, workspaces, outputs);
  *res->mutable_values() = {addr->begin(), addr->end()};
  res->add_len(res->values_size());
}

void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}

inline bool ParameterServer::ReadyForUpdateWeights() const {
  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
}

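// A key accepts a push while not every weight has a full round of pending
// gradients and the key's pull tokens have been consumed.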
inline bool ParameterServer::ReadyForPush(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.empty()) {
    MS_LOG(EXCEPTION) << "The weights in the server are empty. Many reasons could cause this: 1.The Worker didn't send "
                         "the kInitWeightsCmd command. 2.The Server failed to initialize weights.";
  }
  return grad_accum_count_ < weights_.size() && tokens_[key] == 0;
}

inline bool ParameterServer::ReadyForPull(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (tokens_.count(key) == 0 || weights_[key] == nullptr) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  MS_LOG(INFO) << "ReadyForPull: " << (tokens_[key] > 0);
  return tokens_[key] > 0;
}

inline void ParameterServer::ResetGradAccumCount() {
  grad_accum_count_ = 0;
  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
    grads_accum_counter_[iter->first] = 0;
  }
}

const CNodePtr ParameterServer::GetCNode(const std::string &name) const {
  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
  for (CNodePtr cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string fullname = cnode->fullname_with_scope();
    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
      return cnode;
    }
  }
  return nullptr;
}

inline std::mutex &ParameterServer::mutex() { return mutex_; }

void ParameterServer::GetEmbeddingTableParamPtr() {
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    return;
  }

  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  Key count = 0;
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
      auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
      if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
        auto embedding_cnode = embedding_table->cast<CNodePtr>();
        embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
      }
      MS_EXCEPTION_IF_NULL(embedding_table);
      if (embedding_table->isa<Parameter>()) {
        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
        count++;
      }
    }
  }
}

void ParameterServer::CacheEmbeddingTableParamPtr() {
  if (embedding_param_ptr_cached_) {
    return;
  }

  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name != kGatherV2OpName && cnode_name != kSparseGatherV2OpName) {
      continue;
    }

    auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
    if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
      auto embedding_cnode = embedding_table->cast<CNodePtr>();
      embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
    }

    MS_EXCEPTION_IF_NULL(embedding_table);
    if (embedding_table->isa<Parameter>()) {
      (void)embedding_parameter_tables_.emplace(embedding_table->fullname_with_scope(),
                                                embedding_table->cast<ParameterPtr>());
    }
  }

  embedding_param_ptr_cached_ = true;
}

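// Copies every embedding table held by this server back into the default
// parameter tensor of the corresponding Parameter node in the graph.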
void ParameterServer::SyncEmbeddingTables() {
  for (auto embedding_table : embedding_tables_) {
    Key key = embedding_table.first;
    if (embedding_lookup_ops_.count(key) == 0) {
      MS_LOG(WARNING) << "Can't find lookup PS kernel for key " << key;
      continue;
    }
    auto lookup = embedding_lookup_ops_[key];
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    std::vector<int64_t> new_tensor_shape(input_shapes.begin(), input_shapes.end());

    tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
    MS_EXCEPTION_IF_NULL(new_tensor);
    float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
    size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
    size_t embedding_table_size = weights_[key]->size() * sizeof(float);
    if (new_tensor_size != embedding_table_size) {
      MS_LOG(EXCEPTION) << "Shape of embedding table doesn't match. New tensor size:" << new_tensor_size
                        << ", embedding_table size:" << embedding_table_size;
    }
    MS_EXCEPTION_IF_NULL(new_tensor_data_ptr);
    MS_EXCEPTION_IF_NULL(weights_[key]->data());

    CopyTensorData(new_tensor_data_ptr, new_tensor_size, weights_[key]->data());

    auto parameter_tensor_ptr = embedding_table.second->default_param();
    MS_EXCEPTION_IF_NULL(parameter_tensor_ptr);
    parameter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
  }
}

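// Maps every command id to its handler and to a readable name used for logging.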
void ParameterServer::ServerHandler::Init() {
  handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
  handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
  handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
  handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
  handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
  handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
  handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
  handlers_[kPushCmd] = &ServerHandler::HandlePushReq;
  handlers_[kPullCmd] = &ServerHandler::HandlePullReq;
  commands_[kInitWeightsCmd] = "kInitWeightsCmd";
  commands_[kInitWeightToOptimIdCmd] = "kInitWeightToOptimIdCmd";
  commands_[kInitOptimInputsShapeCmd] = "kInitOptimInputsShapeCmd";
  commands_[kInitEmbeddingsCmd] = "kInitEmbeddingsCmd";
  commands_[kCheckReadyForPushCmd] = "kCheckReadyForPushCmd";
  commands_[kCheckReadyForPullCmd] = "kCheckReadyForPullCmd";
  commands_[kEmbeddingLookupCmd] = "kEmbeddingLookupCmd";
  commands_[kUpdateEmbeddingsCmd] = "kUpdateEmbeddingsCmd";
  commands_[kFinalizeCmd] = "kFinalizeCmd";
  commands_[kPushCmd] = "kPushCmd";
  commands_[kPullCmd] = "kPullCmd";
}

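// Dispatches an incoming request to the handler registered for its command and
// sends the (possibly empty) serialized result back over the connection.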
void ParameterServer::ServerHandler::operator()(const std::shared_ptr<core::TcpConnection> &conn,
                                                const std::shared_ptr<core::MessageMeta> &meta, const DataPtr &data,
                                                size_t size) {
  auto output = std::make_shared<std::vector<unsigned char>>();
  if (commands_.count(meta->user_cmd()) == 0) {
    MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!";
  }
  MS_LOG(INFO) << "The command is:" << commands_[meta->user_cmd()];

  auto &handler_ptr = handlers_[meta->user_cmd()];
  (this->*handler_ptr)(data, size, output);
  MS_LOG(DEBUG) << "The output size is:" << output->size();

  if (output->size() > 0) {
    ps_->server_node_->Response(conn, meta, output->data(), output->size());
  } else {
    // If the output is empty, respond with an empty string. Because the Response function is synchronous,
    // the res variable can safely be released as soon as the call returns.
    std::string res;
    ps_->server_node_->Response(conn, meta, res.data(), res.length());
  }
  MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:"
                << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
                     .time_since_epoch()
                     .count();
}

void ParameterServer::ServerHandler::HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  MS_LOG(DEBUG) << "The keys:" << keys << " the values:" << values << " the len:" << lens;
  ps_->AccumGrad(keys, values, lens);
}

void ParameterServer::ServerHandler::HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  KVMessage res_data;
  *res_data.mutable_keys() = input.keys();
  Key key = input.keys()[0];
  auto weight = ps_->weight(key);
  *res_data.mutable_values() = {weight->begin(), weight->end()};
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s failed, errno(" << ret << ")";
  }
}

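// Initializes the weight (and a zeroed gradient buffer) for every key in the
// request that the server does not already hold.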
void ParameterServer::ServerHandler::HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  int key_num = input.keys_size();
  const float *data_ptr = input.values().data();
  size_t pos = 0;
  for (int i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i];

    if (!ps_->HasWeight(key)) {
      WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len));
      MS_EXCEPTION_IF_NULL(weight_ptr);
      ps_->InitWeight(key, weight_ptr);

      GradPtr grad_ptr = std::make_shared<std::vector<float>>(data_len, 0);
      MS_EXCEPTION_IF_NULL(grad_ptr);
      ps_->InitGrad(key, grad_ptr);
    }
    pos += data_len;
  }
}

void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  int key_num = input.keys_size();
  for (int i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    float val = input.values()[i];
    if (init_weight_to_optim_[key]) {
      continue;
    } else {
      init_weight_to_optim_[key] = true;
    }
    ps_->InitWeightKeyToOptims(key, static_cast<int64_t>(val));
  }
}

void ParameterServer::ServerHandler::HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  if (init_optim_info_[key]) {
    return;
  } else {
    init_optim_info_[key] = true;
  }
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  ps_->InitOptimInputsShape(keys, values, lens);
}

void ParameterServer::ServerHandler::HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  EmbeddingTableMeta embedding_table_meta;
  CHECK_RETURN_TYPE(embedding_table_meta.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = embedding_table_meta.key();
  MS_LOG(INFO) << "Initializing embedding table for key:" << key;
  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
  MS_EXCEPTION_IF_NULL(shapes);
  std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.input_shape().begin(), embedding_table_meta.input_shape().end());
  MS_EXCEPTION_IF_NULL(input_shape);
  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.indices_shape().begin(), embedding_table_meta.indices_shape().end());
  MS_EXCEPTION_IF_NULL(indices_shape);
  std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.output_shape().begin(), embedding_table_meta.output_shape().end());
  MS_EXCEPTION_IF_NULL(output_shape);
  shapes->push_back(input_shape);
  shapes->push_back(indices_shape);
  shapes->push_back(output_shape);

  const ParamInitInfoMessage &info = embedding_table_meta.info();
  ParamInitInfo param_init_info;
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    param_init_info.param_name_ = info.param_name();
    param_init_info.param_type_ = static_cast<ParamType>(info.param_type());
    if (param_init_info.param_type_ == kWeight) {
      param_init_info.global_seed_ = info.global_seed();
      param_init_info.op_seed_ = info.op_seed();
    } else if (param_init_info.param_type_ == kAccumulation) {
      param_init_info.init_val_ = info.init_val();
    }
  }
  ps_->InitEmbeddingTable(key, shapes, param_init_info);
}

void ParameterServer::ServerHandler::HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPush(key);
  MS_LOG(INFO) << "The ready is:" << ready;
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s failed, errno(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPull(key);
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s failed, errno(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  EmbeddingTableLookup input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.key();

  KVMessage res_data;
  std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
  *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};

  ps_->DoEmbeddingLookup(key, keys, &res_data);

  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s failed, errno(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
  const Values &update_vals = {input.values().begin(), input.values().end()};
  ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}

void ParameterServer::ServerHandler::HandleFinalize(const DataPtr &, size_t, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  ps_->Finalize();
}
}  // namespace ps
}  // namespace mindspore