/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/worker.h"
#include "pipeline/jit/pipeline.h"

namespace mindspore {
namespace ps {
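// Run() brings the worker online: it guards against a second start, checks
// that this node's role really is "worker", installs cluster-timeout
// callbacks that finalize and exit the process, and then connects the node
// to the scheduler and servers.
//
// A typical lifecycle on the worker side looks like this (sketch only,
// assuming the singleton accessor declared in ps/worker.h):
//   Worker::GetInstance().Run();
//   // ... training steps call Push()/Pull()/DoPSEmbeddingLookup() ...
//   Worker::GetInstance().Finalize();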
void Worker::Run() {
  std::lock_guard<std::mutex> lock(running_mutex_);

  server_num_ = PSContext::instance()->initial_server_num();
  if (running_) {
    MS_LOG(INFO) << "Worker is already running.";
    return;
  }
  if (!PSContext::instance()->is_worker()) {
    MS_LOG(EXCEPTION) << "The role is not worker.";
  }

  Initialize();

  worker_node_.RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT, begin to exit the system!";
    this->Finalize();
    exit(0);
  });
  worker_node_.RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: NODE_TIMEOUT, begin to exit the system!";
    this->Finalize();
    exit(0);
  });

  MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
  worker_node_.Start();
  MS_LOG(INFO) << "Worker connected successfully.";

  running_ = true;
}

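// Push() flattens the gradient tensors located at `addrs` (element counts in
// `sizes`) into a single float buffer and sends it to the servers. Dense
// gradients go through PushData(); for the sparse optimizers (sparse Adam,
// sparse LazyAdam, sparse FTRL) the buffer additionally carries a gradient
// segment and an indices segment whose positions are forwarded to
// PushSparseData(). Note the literal optimizer id 1 below, which sits next
// to the named kSparseLazyAdamIndex/kSparseFtrlIndex constants and appears
// to denote the sparse Adam optimizer.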
void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
  if (keys.size() == 0) {
    MS_LOG(EXCEPTION) << "The key size should be greater than zero.";
  }
  if (key_to_optimId_.count(keys[0]) == 0) {
    MS_LOG(EXCEPTION) << "No optimizer id found for key " << keys[0];
  }
  Key key = keys[0];
  int64_t optim_id = key_to_optimId_[key];
  MS_LOG(INFO) << "The key is:" << key << " the optim_id:" << optim_id;
  bool is_sparse = false;
  if (optim_id == 1 || optim_id == kSparseLazyAdamIndex || optim_id == kSparseFtrlIndex) {
    is_sparse = true;
  }
  int64_t grad_index = -1;
  int64_t indice_index = -1;

  // Sparse adam gradient
  if (optim_id == 1 || optim_id == kSparseLazyAdamIndex) {
    grad_index = kSparseGradIndex;
    indice_index = kSparseIndiceIndex;

    // Sparse ftrl gradient
  } else if (optim_id == kSparseFtrlIndex) {
    grad_index = 0;
    indice_index = 1;
  }

  size_t total_size = LongToSize(std::accumulate(sizes.begin(), sizes.end(), int64_t(0), std::plus<int64_t>()));
  std::vector<float> total_buffer(total_size, 0);
  size_t offset = 0;
  for (size_t i = 0; i < sizes.size(); i++) {
    void *dst_data = total_buffer.data() + offset / sizeof(float);
    void *src_data = reinterpret_cast<void *>(addrs[i]);
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    size_t size = sizes[i] * sizeof(float);
    size_t dest_size = size;
    size_t src_size = size;
    auto ret = memcpy_s(dst_data, dest_size, src_data, src_size);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    offset += size;
  }
  MS_LOG(INFO) << "The total size is:" << total_size;

  while (running_ && (!IsReadyForPush(keys[0]))) {
    continue;
  }
  std::vector<int> sizes_int;
  (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
                       [](const int64_t &value) { return static_cast<int>(value); });
  if (!is_sparse) {
    PushData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), kPushCmd);
  } else {
    std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
    int64_t first_dim_size = var_shape[0];
    int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
    MS_LOG(DEBUG) << "The keys:" << keys << " the total_buffer:" << total_buffer << " the sizes_int:" << sizes_int
                  << " the grad_index:" << grad_index << " the indice_index:" << indice_index
                  << " the first_dim_size:" << first_dim_size << " the outer_dim_size:" << outer_dim_size;
    PushSparseData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), LongToSize(grad_index),
                   LongToSize(indice_index), LongToSize(first_dim_size), LongToSize(outer_dim_size));
  }
}

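// Pull() fetches the current value of `key` from the servers into the
// caller's buffer `dev_addr` of `size` bytes, spinning until the server side
// reports the key as ready for pull.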
void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
  MS_EXCEPTION_IF_NULL(dev_addr);
  std::vector<float> variables(size / sizeof(float), 0);
  while (running_ && (!IsReadyForPull(key))) {
    continue;
  }
  PullData({key}, &variables, nullptr, kPullCmd);
  MS_LOG(DEBUG) << "The variables:" << variables << " the size is:" << size;
  size_t dst_size = size;
  size_t src_size = size;
  auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}

size_t Worker::SetParamKey(const std::string &param_name) {
  size_t key = UINT64_MAX;
  if (param_to_key_.count(param_name)) {
    key = param_to_key_[param_name];
    MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
  } else {
    key = key_cnt_++;
    param_to_key_[param_name] = key;
    MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
  }
  return key;
}

size_t Worker::GetParamKey(const std::string &param_name) {
  size_t key = kInvalidKey;
  if (param_to_key_.find(param_name) != param_to_key_.end()) {
    key = param_to_key_[param_name];
    MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
  }
  return key;
}

void Worker::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
  MS_LOG(DEBUG) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
  param_to_init_in_server_[param_name] = init_in_server;
}

bool Worker::GetParamInitInServer(const std::string &param_name) {
  if (param_to_init_in_server_.count(param_name) == 0) {
    return false;
  }
  return param_to_init_in_server_[param_name];
}

void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
  MS_LOG(INFO) << "SetKeyOptimId key is:" << key << " optimizer_name:" << optimizer_name;
  key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
}

void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
  if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
    key_to_optim_shapes_[key] = {shape};
  } else {
    key_to_optim_shapes_[key].push_back(shape);
  }
}

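// AddEmbeddingTable() registers an embedding table of `row_count` rows and
// records, for every server, the inclusive [begin, end] row range that
// server owns. The per-server row counts come from Util::LocalShard, so for
// example 10 rows over 3 servers could shard as [0,3], [4,6], [7,9]
// (illustrative split; the exact distribution is LocalShard's decision).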
void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
  bool has_init = IsKeyInit(key);
  if (has_init) {
    return;
  }
  uint64_t begin = 0;
  uint64_t end = 0;
  for (int64_t i = 0; i < server_num_; i++) {
    size_t local_row_cnt = LongToSize(Util::LocalShard(row_count, i, server_num_));
    MS_LOG(DEBUG) << "The row_count:" << row_count << " the local_row_cnt:" << local_row_cnt;
    if (i == 0) {
      end = local_row_cnt - 1;
    } else {
      begin = end + 1;
      end += local_row_cnt;
    }
    EmbeddingTableShardMetadata range(begin, end);
    if (embedding_table_ranges_.count(key) == 0) {
      embedding_table_ranges_[key] = std::make_shared<std::vector<EmbeddingTableShardMetadata>>();
      MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
    }
    embedding_table_ranges_[key]->push_back(range);
  }
  embedding_row_cnt_[key] = row_count;
}

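// InitPSEmbeddingTable() serializes an EmbeddingTableMeta (key plus the
// input/indices/output shapes and the parameter-init info) and broadcasts it
// to all servers so each one can allocate and initialize its shard.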
void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
                                  const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
                                  const ParamInitInfoMessage &info) {
  bool has_init = IsKeyInit(key);
  if (has_init) {
    MS_LOG(DEBUG) << "The embedding table of key " << key << " is already initialized.";
    return;
  }

  EmbeddingTableMeta embedding_table_meta;
  embedding_table_meta.set_key(key);
  *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
  *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
  *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
  *embedding_table_meta.mutable_info() = info;

  std::string kv_data = embedding_table_meta.SerializeAsString();

  std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  size_t dest_size = kv_data.length();
  int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  if (ret != 0) {
    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }

  worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kInitEmbeddingsCmd);
}

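// InitPSParamAndOptim() is invoked per parameter node: it records whether
// the parameter is initialized on the server side and, if the parameter
// carries a valid PS key that has not been initialized yet, registers the
// key's server assignment and pushes the initial weights, optimizer id, and
// optimizer input shapes (skipped when the embedding cache is enabled).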
void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(input_node);
  auto pk_node = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(pk_node);
  const std::string &param_name = pk_node->fullname_with_scope();
  void *param_data = tensor->data_c();
  size_t param_size = LongToSize(tensor->data().nbytes());

  size_t param_key = GetParamKey(param_name);
  if (param_key == kInvalidKey) {
    MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
    return;
  }
  bool init_in_server = false;
  auto param_info_ptr = pk_node->param_info();
  if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
    init_in_server = true;
  }
  SetParamInitInServer(param_name, init_in_server);
  bool init = IsKeyInit(param_key);
  if (!init) {
    MS_LOG(DEBUG) << "Init parameter key " << param_key << " and optimizer in parameter server side for " << param_name
                  << ", whether init in server: " << init_in_server;
    AddKeyToServerId(param_key);
    if (!PsDataPrefetch::GetInstance().cache_enable()) {
      if (!init_in_server) {
        if (param_size > INT_MAX) {
          MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
                            << param_size;
        }
        InitPSParamData({param_key}, param_data, param_size);
      }
      InitPSOptimId(param_key);
      InitPSOptimInputShapes(param_key);
    }
  }
}

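// DoPSEmbeddingLookup() partitions `lookup_ids` by embedding shard, queries
// the owning servers, and stitches the returned rows back into
// `lookup_result` in the original id order. Each id occupies single_id_len
// floats; ids absent from the responses leave their slot in `lookup_result`
// unmodified.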
void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
                                 int64_t cmd) {
  MS_EXCEPTION_IF_NULL(lookup_result);
  EmbeddingTableLookup embedding_table_lookup;
  embedding_table_lookup.set_key(key);
  *embedding_table_lookup.mutable_keys() = {lookup_ids.begin(), lookup_ids.end()};

  PartitionEmbeddingMessages messages;
  lookup_partitioner_(embedding_table_lookup, &messages, {});
  std::vector<uint32_t> rank_ids;
  std::vector<DataPtr> data;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < messages.size(); i++) {
    if (messages.at(i).first) {
      rank_ids.push_back(i);
      std::string kv_data = messages.at(i).second.SerializeAsString();

      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      data.push_back(res);
      sizes.push_back(kv_data.length());
    }
  }

  std::vector<VectorPtr> resp;
  if (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(cmd), &resp)) {
    MS_LOG(ERROR) << "Worker send failed!";
  }
  int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
  std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
  std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
  std::shared_ptr<std::vector<Key>> keys = std::make_shared<std::vector<Key>>();
  int64_t value_offset = 0;
  for (size_t i = 0; i < resp.size(); ++i) {
    KVMessage message;
    CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
    for (auto j = 0; j < message.values_size(); j++) {
      values->push_back(message.values(j));
    }
    for (auto k = 0; k < message.keys_size(); k++) {
      const Key &message_key = message.keys(k);
      keys->push_back(message_key);
    }
  }

  for (size_t i = 0; i < keys->size(); i++) {
    const Key &map_key = keys->at(i);
    float *addr = values->data() + value_offset;
    value_offset += single_id_len;
    id_addr_map[map_key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
  }

  float *result_addr = lookup_result->data();
  MS_EXCEPTION_IF_NULL(result_addr);
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    if (id_addr_map.count(lookup_ids[i]) == 0) {
      offset += single_id_len;
      continue;
    }
    const Key &id_key = static_cast<Key>(lookup_ids[i]);
    auto &pair = id_addr_map[id_key];
    size_t size = LongToSize(single_id_len * sizeof(float));
    dst_size = size;
    src_size = size;
    dst_data = result_addr + offset;
    src_data = pair->first;
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
    if (mem_ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << mem_ret << ")";
      return;
    }
    offset += single_id_len;
  }
}

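// UpdateEmbeddingTable() sends updated embedding rows (`vals` holds one row
// of embedding_dim floats per id in `lookup_ids`) to the servers that own
// each id, using update_embedding_partitioner_ to split the payload.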
void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
                                  const std::vector<float> &vals) {
  KVMessage kvs;
  *kvs.mutable_keys() = {keys.begin(), keys.end()};
  *kvs.mutable_len() = {lookup_ids.begin(), lookup_ids.end()};
  *kvs.mutable_values() = {vals.begin(), vals.end()};
  PartitionKVMessages messages;
  update_embedding_partitioner_(kvs, &messages, {});
  std::vector<uint32_t> rank_ids;
  std::vector<DataPtr> data;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < messages.size(); i++) {
    if (messages.at(i).first) {
      rank_ids.push_back(i);
      std::string kv_data = messages.at(i).second.SerializeAsString();

      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      data.push_back(res);
      sizes.push_back(kv_data.length());
    }
  }
  (void)worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(kUpdateEmbeddingsCmd));
}

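// Finalize() shuts the worker down: it broadcasts kFinalizeCmd to the
// servers (the single 0/0.0f key-value pair appears to serve only as a
// non-empty payload), then finishes and stops the communication node.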
void Worker::Finalize() {
  if (running_) {
    MS_LOG(INFO) << "Worker starts finalizing...";
    KVMessage kvs;
    kvs.add_keys(0);
    kvs.add_values(0.0f);
    std::string kv_data = kvs.SerializeAsString();
    std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
    size_t dest_size = kv_data.length();
    int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kFinalizeCmd);
    worker_node_.Finish();
    worker_node_.Stop();
    running_ = false;
    MS_LOG(INFO) << "Worker finalized successfully.";
  }
}

void Worker::Initialize() {
  lookup_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    LookupIdPartitioner(send, partition, attrs);
  };
  worker_init_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    WorkerInitEmbeddingPartitioner(send, partition, attrs);
  };
  round_robin_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    RoundRobinPartitioner(send, partition, attrs);
  };
  sparse_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    SparsePartitioner(send, partition, attrs);
  };
  update_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    UpdateEmbeddingPartitioner(send, partition, attrs);
  };
  broadcast_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
    BroadcastPartitioner(send, partition, attrs);
  };
}

bool Worker::IsKeyInit(const size_t key) {
  if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
    return false;
  }
  return true;
}

void Worker::AddKeyToServerId(const Key &key) { AddKeyByHashMod(key); }

void Worker::AddKeyByHashMod(const Key &key) {
  if (server_num_ == 0) {
    MS_LOG(EXCEPTION) << "Server number is invalid: 0";
  }
  key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
  MS_LOG(DEBUG) << "The server id of key " << key << " is " << key_to_server_id_[key];
}

void Worker::InitPSOptimId(const size_t param_key) {
  MS_LOG(INFO) << "InitPSOptimId key is:" << param_key;
  if (key_to_optimId_.count(param_key) == 0) {
    MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
  }
  int64_t optim_id = key_to_optimId_[param_key];

  std::vector<Key> keys = {param_key};
  std::vector<float> optim_id_vals = {static_cast<float>(optim_id)};
  std::vector<int> optim_id_lens = {SizeToInt(optim_id_vals.size())};
  MS_LOG(INFO) << "The keys are:" << keys << " the optim_id_vals are:" << optim_id_vals
               << " the optim_id_lens are:" << optim_id_lens;
  PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
}

void Worker::InitPSOptimInputShapes(const size_t key) {
  std::vector<Key> keys;
  std::vector<int> shape_len;
  std::vector<float> all_shape;
  std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
  for (auto shape : shapes) {
    keys.push_back(key);
    if (shape.size() == 0) {
      shape_len.push_back(1);
      all_shape.push_back(1);
    } else {
      shape_len.push_back(SizeToInt(shape.size()));
      std::transform(shape.begin(), shape.end(), std::back_inserter(all_shape),
                     [](int64_t dim) -> float { return static_cast<float>(dim); });
    }
  }
  MS_LOG(INFO) << "keys:" << keys;
  MS_LOG(INFO) << "shape_len:" << shape_len;
  MS_LOG(INFO) << "all_shape:" << all_shape;
  if (!init_keys_[key]) {
    init_keys_[key] = true;
  }
  PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
}

void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) {
  MS_EXCEPTION_IF_NULL(origin_addr);
  std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
                          reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
  std::vector<Key> key(keys);
  std::vector<int> lens;
  lens.push_back(SizeToInt(addr.size()));
  MS_LOG(INFO) << "The keys are:" << keys;
  MS_LOG(INFO) << "The values are:" << addr;
  PushData(key, addr, lens, kInitWeightsCmd);
  init_keys_[key[0]] = true;
}

bool Worker::IsReadyForPush(const Key &key) {
  std::vector<float> result(1, 0);
  PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  if (result[0] > 0) {
    MS_LOG(INFO) << "Key " << key << " is ready for push.";
    return true;
  } else {
    MS_LOG(INFO) << "Key " << key << " is not ready for push.";
    return false;
  }
}

bool Worker::IsReadyForPull(const Key &key) {
  std::vector<float> result(1, 0);
  PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  if (result[0] > 0) {
    MS_LOG(INFO) << "Key " << key << " is ready for pull.";
    return true;
  } else {
    MS_LOG(INFO) << "Key " << key << " is not ready for pull.";
    return false;
  }
}

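// PrepareSparseGradient() copies, for each (id, grad) pair whose id belongs
// to this shard's distinct_ids, the id into `indices` and its segment of
// `segment_size` gradient floats into `gradient`, preserving the order of
// indice_to_grads.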
void Worker::PrepareSparseGradient(const size_t, const size_t, const std::unordered_set<int> &distinct_ids,
                                   const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
                                   const size_t segment_size, float *gradient, int *indices) {
  MS_EXCEPTION_IF_NULL(all_indice);
  MS_EXCEPTION_IF_NULL(gradient);
  MS_EXCEPTION_IF_NULL(indices);
  size_t offset = 0;
  int64_t index = 0;
  size_t segment_data_size = segment_size * sizeof(float);
  size_t dst_size;
  size_t src_size;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (auto &pair : indice_to_grads) {
    if (distinct_ids.count(pair.first) == 0) {
      continue;
    }
    indices[index++] = pair.first;

    dst_size = segment_data_size;
    src_size = segment_data_size;
    dst_data = gradient + offset;
    src_data = pair.second;
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    offset += segment_size;
  }
}

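// BuildSparseValue() reassembles a flat value buffer after sparse reduction:
// segments other than the gradient and indices are copied through from
// `original_data`, then the reduced gradient and the reduced indices (stored
// bitwise in float-sized slots) are written at the offsets described by
// `lengths`.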
void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
                              const float *original_data, const float *grads, int *indices,
                              std::vector<float> *reduced_data) {
  MS_EXCEPTION_IF_NULL(original_data);
  MS_EXCEPTION_IF_NULL(grads);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(reduced_data);
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (size_t i = 0; i < lengths.size(); i++) {
    if (i != grad_index && i != indice_index) {
      size_t data_size = lengths[i] * sizeof(float);
      dst_size = data_size;
      src_size = data_size;
      dst_data = reduced_data->data() + offset;
      src_data = const_cast<float *>(original_data) + offset;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (mem_ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << mem_ret << ")";
        return;
      }
    }
    offset += lengths[i];
  }

  // Fill the reduced gradient
  int64_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += lengths[i];
  }
  size_t data_size = lengths[grad_index] * sizeof(float);
  dst_size = data_size;
  src_size = data_size;
  dst_data = reduced_data->data() + grad_offset;
  src_data = const_cast<float *>(grads);
  MS_EXCEPTION_IF_NULL(dst_data);
  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }

  // Fill the reduced indice
  int64_t indice_offset = grad_offset + lengths[grad_index];
  data_size = lengths[indice_index] * sizeof(float);
  float *indice_data = reduced_data->data() + indice_offset;
  dst_size = data_size;
  src_size = data_size;
  dst_data = indice_data;
  src_data = indices;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}

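// PushData() is the dense push path. Keys backed by an embedding table are
// either partitioned with worker_init_embedding_partitioner_ (for the weight
// init command) or broadcast whole to every server; all other keys are
// routed per key by round_robin_partitioner_.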
void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
                      int cmd, int64_t) {
  KVMessage kvs;
  *kvs.mutable_keys() = {keys.begin(), keys.end()};
  *kvs.mutable_values() = {vals.begin(), vals.end()};
  *kvs.mutable_len() = {lens.begin(), lens.end()};
  MS_LOG(INFO) << "The embedding table range count of key " << keys[0] << " is: "
               << embedding_table_ranges_.count(keys[0]);
  if (embedding_table_ranges_.count(keys[0])) {
    if (cmd == kInitWeightsCmd) {
      SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
    } else {
      std::string kv_data = kvs.SerializeAsString();
      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), cmd);
    }
  } else {
    SendForPush(cmd, kvs, round_robin_partitioner_, {});
  }
}

void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
                            size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size) {
  KVMessage kvs;
  *kvs.mutable_keys() = {keys.begin(), keys.end()};
  *kvs.mutable_values() = {vals.begin(), vals.end()};
  *kvs.mutable_len() = {lens.begin(), lens.end()};
  if (embedding_table_ranges_.count(keys[0])) {
    std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
    SendForPush(kPushCmd, kvs, sparse_partitioner_, attrs);
  } else {
    SendForPush(kPushCmd, kvs, round_robin_partitioner_, {});
  }
}

void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd,
                      int64_t priority) {
  MS_EXCEPTION_IF_NULL(vals);
  KVMessage kvs;
  *kvs.mutable_keys() = {keys.begin(), keys.end()};
  if (embedding_table_ranges_.count(keys[0])) {
    SendForPull(cmd, kvs, broadcast_partitioner_, {}, vals, lens);
  } else {
    SendForPull(cmd, kvs, round_robin_partitioner_, {}, vals, lens);
  }
}

void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
                                 const std::map<int64_t, int64_t> &) {
  MS_EXCEPTION_IF_NULL(partition);

  const Key &key = send.key();
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());

  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    std::unordered_set<int32_t> unique_ids;
    auto &kvs = partition->at(i).second;

    kvs.set_key(key);

    std::for_each(send.keys().begin(), send.keys().end(), [&](int32_t lookup_id) {
      if (lookup_id >= SizeToInt(begin) && lookup_id <= SizeToInt(end)) {
        unique_ids.insert(lookup_id);
      }
    });
    MS_LOG(DEBUG) << "The unique ids size is:" << unique_ids.size();

    for (const auto &lookup_id : unique_ids) {
      kvs.add_keys(lookup_id);
      kvs.add_values(0.0f);
    }

    if (kvs.keys().empty()) {
      partition->at(i).first = false;
    } else {
      partition->at(i).first = true;
    }
  }
}

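// SparsePartitioner() splits one sparse push across the embedding shards.
// For every server's row range it collects the ids (and their gradient
// segments) falling into the range, merges duplicate ids with
// Util::ReduceSparseGradient, and rebuilds a compacted KVMessage via
// BuildSparseValue(); shards with no matching ids still receive a small
// placeholder message so every server sees the push. The attrs map carries
// {kGradIndex, kIndiceIndex, kFirstDimSize, kOutDimSize}, filled in by
// PushSparseData().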
void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
                               const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(partition);
  // Init variables
  float *data = const_cast<float *>(send.values().data());

  if (attrs.count(kGradIndex) == 0 || attrs.count(kIndiceIndex) == 0 || attrs.count(kFirstDimSize) == 0 ||
      attrs.count(kOutDimSize) == 0) {
    MS_LOG(EXCEPTION) << "Invalid attrs keys";
  }
  auto iter = attrs.find(kGradIndex);
  size_t grad_index = static_cast<size_t>(iter->second);
  iter = attrs.find(kIndiceIndex);
  size_t indice_index = static_cast<size_t>(iter->second);
  iter = attrs.find(kFirstDimSize);
  size_t first_dim_size = static_cast<size_t>(iter->second);
  iter = attrs.find(kOutDimSize);
  size_t outer_dim_size = static_cast<size_t>(iter->second);

  size_t grad_size = send.len()[SizeToInt(grad_index)];
  size_t indice_size = send.len()[SizeToInt(indice_index)];
  size_t segment_size = grad_size / indice_size;

  size_t grad_offset = 0;
  size_t indice_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += send.len()[i];
  }
  for (size_t j = 0; j < indice_index; j++) {
    indice_offset += send.len()[j];
  }

  float *grad_data = data + grad_offset;
  void *indice_data_temp = data + indice_offset;
  int *indice_data = reinterpret_cast<int *>(indice_data_temp);

  // Build the mappings of indice to gradient
  std::vector<std::pair<int, float *>> indice_to_grads;
  for (size_t i = 0; i < indice_size; i++) {
    int indice = indice_data[i];
    float *grad = grad_data + i * segment_size;
    indice_to_grads.push_back(std::make_pair(indice, grad));
  }

  const Key &key = send.keys()[0];
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());

  // Construct reduced sparse data for each server
  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = partition->at(i).second;
    *kvs.mutable_keys() = {send.keys().begin(), send.keys().end()};
    *kvs.mutable_len() = {send.len().begin(), send.len().end()};

    // Prepare the sparse gradient and indice
    std::vector<int> indice_ids;
    std::unordered_set<int> distinct_ids;
    for (size_t j = 0; j < indice_size; j++) {
      size_t indice = static_cast<size_t>(indice_data[j]);
      if (indice >= begin && indice <= end) {
        indice_ids.push_back(indice);
        distinct_ids.insert(indice);
      }
    }
    size_t indices_size = indice_ids.size();
    if (indices_size > 0) {
      size_t partition_segment_size = indices_size * segment_size;
      std::vector<float> src_grad_data(partition_segment_size);
      std::vector<int> src_indice_data(indices_size);
      PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
                            src_indice_data.data());

      // Reduce the sparse gradient and indice
      std::vector<float> new_grad(partition_segment_size);
      std::vector<int> new_indices(indices_size);
      mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
      Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
                                 first_dim_size, outer_dim_size, &unique_sparse_grad);

      // Update the length of reduce sparse gradient and indice
      std::vector<int> reduced_lens = {kvs.len().begin(), kvs.len().end()};
      reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
      reduced_lens[indice_index] = unique_sparse_grad.indices_size_;

      // Build the sparse value to be sent
      size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
      std::vector<float> reduced_data(total_size, 0);
      BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
                       unique_sparse_grad.indices_, &reduced_data);

      *kvs.mutable_len() = {reduced_lens.begin(), reduced_lens.end()};
      *kvs.mutable_values() = {reduced_data.begin(), reduced_data.end()};
    }

    if (indices_size == 0) {
      std::vector<float> no_keys;
      std::vector<float> no_vals;
      std::vector<float> no_lens;
      no_keys.push_back(key);
      no_vals.push_back(kGradValue);
      *kvs.mutable_values() = {no_vals.begin(), no_vals.end()};
      *kvs.mutable_len() = {no_lens.begin(), no_lens.end()};
    }
    partition->at(i).first = true;
  }
}

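// RoundRobinPartitioner() routes each key, together with its
// length-delimited slice of `values`, to the server previously assigned in
// AddKeyByHashMod(), i.e. key % server_num_.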
void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                   const std::map<int64_t, int64_t> &) {
  MS_EXCEPTION_IF_NULL(partition);
  partition->resize(LongToSize(server_num_));
  auto keys = send.keys();
  auto values = send.values();
  auto lens = send.len();
  MS_LOG(INFO) << "The key size is:" << send.keys_size() << " the values size is:" << send.values_size()
               << " the lens:" << send.len_size();

  size_t len;
  Key param_key;
  for (int i = 0; i < send.keys_size(); i++) {
    param_key = keys[i];
    int64_t server_id = key_to_server_id_[param_key];
    if (!partition->at(LongToUlong(server_id)).first) {
      partition->at(LongToUlong(server_id)).first = true;
    }

    KVMessage &server_kv_pairs = partition->at(LongToUlong(server_id)).second;
    server_kv_pairs.add_keys(param_key);
    if (values.empty()) {
      continue;
    }
    len = lens[i];
    int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, int64_t(0));
    auto val_begin = values.begin() + offset;
    auto val_end = val_begin + len;
    for (auto it = val_begin; it != val_end; ++it) {
      server_kv_pairs.add_values(*it);
    }
    server_kv_pairs.add_len(len);
  }
}

void Worker::WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
                                            const std::map<int64_t, int64_t> &) {
  MS_EXCEPTION_IF_NULL(partition);
  partition->resize(LongToSize(server_num_));
  auto keys = send.keys();
  auto values = send.values();
  auto lens = send.len();

  int32_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[keys[0]]);
  for (size_t i = 0; i < ranges.size(); i++) {
    size_t offset_begin = ranges[i].begin() * col_cnt;
    size_t offset_end = (ranges[i].end() + 1) * col_cnt;
    KVMessage kvs;
    *kvs.mutable_keys() = keys;
    *kvs.mutable_values() = {values.begin() + offset_begin, values.begin() + offset_end};
    kvs.add_len(offset_end - offset_begin);
    partition->at(i).first = true;
    partition->at(i).second = kvs;
  }
}

void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                        const std::map<int64_t, int64_t> &) {
  MS_EXCEPTION_IF_NULL(partition);
  const float *embedding_vals = send.values().data();
  const uint64_t *lookup_ids = send.len().data();
  size_t val_size = IntToSize(send.values_size());
  size_t id_size = IntToSize(send.len_size());
  if (id_size == 0) {
    MS_LOG(EXCEPTION) << "The id size is 0.";
    return;
  }
  size_t embedding_dim = val_size / id_size;

  const Key &key = send.keys()[0];
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());

  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = partition->at(i).second;
    kvs.add_keys(key);
    for (size_t j = 0; j < id_size; j++) {
      auto lookup_id = lookup_ids[j];
      if (lookup_id >= begin && lookup_id <= end) {
        kvs.add_keys(lookup_id);
        for (size_t k = 0; k < embedding_dim; k++) {
          kvs.add_values(embedding_vals[j * embedding_dim + k]);
        }
      }
    }

    if (kvs.keys_size() <= 1) {
      partition->at(i).first = false;
    } else {
      partition->at(i).first = true;
    }
  }
}

void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                  const std::map<int64_t, int64_t> &) {
  MS_EXCEPTION_IF_NULL(partition);
  partition->resize(LongToSize(server_num_));
  for (size_t i = 0; i < LongToSize(server_num_); i++) {
    partition->at(i).first = true;
    partition->at(i).second = send;
  }
}

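// SendForPush()/SendForPull() serialize every non-empty partition produced
// by the given partitioner and send one message per target server; the pull
// variant additionally collects the responses and concatenates the returned
// values (and lengths, when requested) in response order.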
void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                         const std::map<int64_t, int64_t> &attrs) {
  PartitionKVMessages messages;
  partitioner(send, &messages, attrs);
  std::vector<uint32_t> rank_ids;
  std::vector<DataPtr> data;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < messages.size(); i++) {
    if (messages.at(i).first) {
      rank_ids.push_back(i);
      std::string kv_data = messages.at(i).second.SerializeAsString();

      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      data.push_back(res);
      sizes.push_back(kv_data.length());
    }
  }
  (void)worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd);
}

void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
                         const std::map<int64_t, int64_t> &, std::vector<float> *vals, std::vector<int> *lens) {
  MS_EXCEPTION_IF_NULL(vals);
  PartitionKVMessages messages;
  partitioner(send, &messages, {});
  std::vector<uint32_t> rank_ids;
  std::vector<DataPtr> data;
  std::vector<size_t> sizes;
  for (size_t i = 0; i < messages.size(); i++) {
    if (messages.at(i).first) {
      rank_ids.push_back(i);
      std::string kv_data = messages.at(i).second.SerializeAsString();

      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      data.push_back(res);
      sizes.push_back(kv_data.length());
    }
  }
  std::vector<VectorPtr> resp;
  (void)worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
  vals->clear();
  for (size_t i = 0; i < resp.size(); ++i) {
    KVMessage message;
    CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
    std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));

    if (lens) {
      lens->clear();
      std::copy(message.len().begin(), message.len().end(), std::back_inserter(*lens));
    }
  }
}
}  // namespace ps
}  // namespace mindspore