• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include "ps/ps_cache/ps_cache_manager.h"
19 #include "utils/log_adapter.h"
20 #include "utils/ms_utils.h"
21 
22 using mindspore::kernel::Address;
23 namespace mindspore {
24 namespace ps {
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size)25 void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
26                                          size_t vocab_size) {
27   if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
28     MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
29   }
30   hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
31   hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
32   hash_tables_[param_name].embedding_size = embedding_size;
33   hash_tables_[param_name].vocab_size = vocab_size;
34 
35   if (vocab_size_ == 0) {
36     vocab_size_ = vocab_size;
37   }
38   if (vocab_cache_size_ == 0) {
39     vocab_cache_size_ = cache_vocab_size;
40   }
41   if (host_vocab_cache_size_ == 0) {
42     host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
43   }
44 }
45 
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name,size_t cache_vocab_size,size_t embedding_size)46 void PsCacheManager::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
47                                            size_t cache_vocab_size, size_t embedding_size) {
48   if (cache_vocab_size == 0 || embedding_size == 0) {
49     MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
50   }
51   if (new_param_name.empty() || cur_param_name.empty()) {
52     MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
53   }
54   if (new_param_name == cur_param_name) {
55     return;
56   }
57   auto iter = hash_tables_.find(cur_param_name);
58   if (iter != hash_tables_.end()) {
59     hash_tables_.emplace(new_param_name, iter->second);
60     hash_tables_.erase(iter);
61   } else {
62     hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size;
63     hash_tables_[new_param_name].embedding_size = embedding_size;
64   }
65 }
66 
InsertWeightInitInfo(const std::string & param_name,size_t global_seed,size_t op_seed)67 void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) {
68   auto iter = hash_tables_.find(param_name);
69   if (iter == hash_tables_.end()) {
70     MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
71   }
72   auto &hash_table_info = iter->second;
73   if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
74     return;
75   }
76   MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed
77                << ", op seed:" << op_seed;
78   hash_table_info.param_init_info_.param_name_ = param_name;
79   hash_table_info.param_init_info_.param_type_ = kWeight;
80   hash_table_info.param_init_info_.global_seed_ = global_seed;
81   hash_table_info.param_init_info_.op_seed_ = op_seed;
82   if (CheckFinishInsertInitInfo()) {
83     finish_insert_init_info_ = true;
84     insert_init_info_.notify_one();
85   }
86 }
87 
InsertAccumuInitInfo(const std::string & param_name,float init_val)88 void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float init_val) {
89   auto iter = hash_tables_.find(param_name);
90   if (iter == hash_tables_.end()) {
91     MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
92   }
93   auto &hash_table_info = iter->second;
94   if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
95     return;
96   }
97   MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
98   hash_table_info.param_init_info_.param_name_ = param_name;
99   hash_table_info.param_init_info_.param_type_ = kAccumulation;
100   hash_table_info.param_init_info_.init_val_ = init_val;
101   if (CheckFinishInsertInitInfo()) {
102     finish_insert_init_info_ = true;
103     insert_init_info_.notify_one();
104   }
105 }
106 
CheckFinishInsertInitInfo() const107 bool PsCacheManager::CheckFinishInsertInitInfo() const {
108   for (const auto &item : hash_tables_) {
109     const auto &hash_table_info = item.second;
110     const auto &param_init_info = hash_table_info.param_init_info_;
111     if (param_init_info.param_type_ == kUnKnown) {
112       return false;
113     }
114   }
115   MS_LOG(INFO) << "Finish inserting embedding table init info.";
116   return true;
117 }
118 
CloneHashTable(const std::string & dest_param_name,const std::string & src_param_name)119 void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
120   if (dest_param_name == src_param_name) {
121     MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
122     return;
123   }
124   auto iter = hash_tables_.find(src_param_name);
125   if (iter == hash_tables_.end()) {
126     MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
127   }
128   hash_tables_.emplace(dest_param_name, iter->second);
129 }
130 
QueryHashTableAddr(const std::string & param_name) const131 const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const {
132   auto iter = hash_tables_.find(param_name);
133   if (iter == hash_tables_.end()) {
134     MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
135   }
136   return iter->second.device_address;
137 }
138 
QueryHashTableSize(const std::string & param_name) const139 const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) const {
140   auto iter = hash_tables_.find(param_name);
141   if (iter == hash_tables_.end()) {
142     MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
143   }
144   return iter->second.cache_vocab_size;
145 }
146 
Initialize()147 void PsCacheManager::Initialize() {
148   MS_LOG(INFO) << "PS cache initialize.";
149   if (!Worker::GetInstance().running()) {
150     Worker::GetInstance().Run();
151   }
152   embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
153   MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
154   embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
155   MS_ERROR_IF_NULL_WO_RET_VAL(embedding_host_cache_);
156   AddEmbeddingTable();
157   AllocMemForHashTable();
158   SetLocalIdRank();
159   DumpHashTables();
160   initialized_ps_cache_ = true;
161 }
162 
AddEmbeddingTable() const163 void PsCacheManager::AddEmbeddingTable() const {
164   for (const auto &item : hash_tables_) {
165     const auto &param_name = item.first;
166     size_t key = Worker::GetInstance().SetParamKey(param_name);
167     size_t row_count = item.second.vocab_size;
168     // if worker role
169     Worker::GetInstance().AddEmbeddingTable(key, row_count);
170   }
171 }
172 
InitParameterServer()173 void PsCacheManager::InitParameterServer() {
174   MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_;
175   std::unique_lock<std::mutex> locker(data_mutex_);
176   insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; });
177   if (!running_) {
178     return;
179   }
180   for (const auto &item : hash_tables_) {
181     const auto &param_name = item.first;
182     size_t key = Worker::GetInstance().SetParamKey(param_name);
183     const auto &hash_table_info = item.second;
184     const auto &param_init_info = hash_table_info.param_init_info_;
185 
186     std::vector<size_t> input_shape = {item.second.vocab_size, item.second.embedding_size};
187     std::vector<size_t> indices_shape = {1, 1};
188     std::vector<size_t> output_shape = {1, 1, 1};
189     ParamInitInfoMessage info;
190     info.set_param_name(param_name);
191     info.set_param_type(param_init_info.param_type_);
192     info.set_init_val(param_init_info.init_val_);
193     info.set_global_seed(param_init_info.global_seed_);
194     info.set_op_seed(param_init_info.op_seed_);
195     // if worker role
196     Worker::GetInstance().InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
197   }
198 
199   finish_init_parameter_server_ = true;
200   data_prase_.notify_one();
201   MS_LOG(INFO) << "PS embedding cache table init end.";
202 }
203 
InitDataChannel()204 void PsCacheManager::InitDataChannel() {
205   MS_LOG(INFO) << "PS embedding cache data channel init begin.";
206   auto channel = channel_name();
207   if (channel.empty()) {
208     std::unique_lock<std::mutex> locker(data_mutex_);
209     data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
210     if (!running_) {
211       return;
212     }
213   }
214   MS_LOG(INFO) << "PS embedding cache data channel  init end.";
215 }
216 
AllocMemForHashTable()217 void PsCacheManager::AllocMemForHashTable() {
218   MS_EXCEPTION_IF_NULL(embedding_device_cache_);
219   MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
220   size_t max_embedding_size = 0;
221   for (auto &item : hash_tables_) {
222     size_t embedding_size = item.second.embedding_size;
223     auto &device_address = item.second.device_address;
224     device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
225     auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
226     MS_EXCEPTION_IF_NULL(addr);
227     device_address.addr = addr;
228 
229     auto &host_address = item.second.host_address;
230     std::unique_ptr<float[]> host_hash_table_addr = std::make_unique<float[]>(host_vocab_cache_size_ * embedding_size);
231     MS_EXCEPTION_IF_NULL(host_hash_table_addr);
232     host_address = std::move(host_hash_table_addr);
233     MS_EXCEPTION_IF_NULL(host_address);
234 
235     max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
236   }
237   embedding_device_cache_->hash_swap_index_addr_ =
238     reinterpret_cast<int *>(embedding_device_cache_->cache_->MallocMemory(batch_elements_ * sizeof(int)));
239   MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
240   embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
241     embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
242   MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
243 }
244 
SetLocalIdRank()245 void PsCacheManager::SetLocalIdRank() {
246   auto worker_num = PSContext::instance()->initial_worker_num();
247   if (worker_num > 0) {
248     auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
249     vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
250     emb_table_slice_bounds_.first = local_shard_size * rank_id_;
251     emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
252     cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
253     cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
254     MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
255                  << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
256                  << ", cache indices begin: " << cache_indices_bounds_.first
257                  << ", cache indices end: " << cache_indices_bounds_.second
258                  << ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
259   }
260 }
261 
cache_indices_lower_bound() const262 int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
263 
channel_name()264 std::string PsCacheManager::channel_name() {
265   std::lock_guard<std::mutex> locker(channel_mutex_);
266   return channel_name_;
267 }
268 
set_channel_name(const std::string channel_name)269 void PsCacheManager::set_channel_name(const std::string channel_name) {
270   if (channel_name_ == channel_name) {
271     return;
272   }
273   std::lock_guard<std::mutex> locker(channel_mutex_);
274   channel_name_ = channel_name;
275 }
276 
IncreaseStep()277 bool PsCacheManager::IncreaseStep() {
278   if (data_step_ >= UINT64_MAX) {
279     MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
280     return false;
281   }
282   data_step_++;
283   set_current_graph_step();
284   if (graph_running_step_ > data_step_) {
285     MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
286                   << ").";
287     return false;
288   }
289   return true;
290 }
291 
IncreaseGraphStep(const std::string & channel_name)292 void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
293   if (!running_) {
294     MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
295   }
296   if (graph_step_ >= UINT64_MAX) {
297     MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
298   }
299   if (graph_step_ == 0) {
300     MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
301     std::unique_lock<std::mutex> locker(data_mutex_);
302     data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); });
303     if (!running_) {
304       MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
305     }
306     MS_LOG(INFO) << "Graph running waiting embedding table init end.";
307   }
308   graph_step_++;
309   set_channel_name(channel_name);
310   if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
311     MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
312   }
313   data_prase_.notify_one();
314 }
315 
DoProcessData(uint32_t device_id,const void * context)316 void PsCacheManager::DoProcessData(uint32_t device_id, const void *context) {
317   // PS embeddingLookup cache check.
318   if (!initialized_ps_cache_) {
319     MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training "
320                          "mode, current dataset mode is not sink_mode.";
321   }
322   process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
323 }
324 
ProcessDataTask(uint32_t device_id,const void * context)325 void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) {
326   MS_LOG(INFO) << "PS embedding cache process data task begin.";
327   running_ = true;
328   MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
329   MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_->cache_);
330   embedding_device_cache_->cache_->InitDevice(device_id, context);
331 
332   // MallocConstantMemory need stream on device Ascend, should be called after InitDevice.
333   if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
334     MS_LOG(ERROR) << "MallocConstantMemory failed.";
335     running_ = false;
336     return;
337   }
338 
339   InitParameterServer();
340   InitDataChannel();
341   while (running_) {
342     if (!ProcessData()) {
343       running_ = false;
344     }
345   }
346   MS_LOG(INFO) << "PS embedding cache process data task end.";
347 }
348 
Finalize()349 void PsCacheManager::Finalize() {
350   if (running_) {
351     SyncEmbeddingTable();
352   }
353   running_ = false;
354   PsDataPrefetch::GetInstance().NotifyFinalize();
355   insert_init_info_.notify_all();
356   data_prase_.notify_all();
357   if (process_data_thread_.joinable()) {
358     process_data_thread_.join();
359   }
360 }
361 
ProcessData()362 bool PsCacheManager::ProcessData() {
363   struct timeval start_time, end_time;
364   const uint64_t kUSecondInSecond = 1000000;
365   (void)gettimeofday(&start_time, nullptr);
366   void *data = nullptr;
367   if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
368     return false;
369   }
370   if (data == nullptr) {
371     MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
372     std::unique_lock<std::mutex> locker(data_mutex_);
373     const int64_t longest_time_to_wait = 100;
374     (void)data_prase_.wait_for(locker, std::chrono::milliseconds(longest_time_to_wait));
375     return true;
376   }
377   RETURN_IF_FALSE(IncreaseStep());
378   auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_);
379   if (data_size == 0) {
380     MS_LOG(ERROR) << "The data_size can not be zero.";
381     return false;
382   }
383   auto batch_ids = reinterpret_cast<int *>(data);
384   auto batch_ids_len = data_size / sizeof(int);
385   std::unique_ptr<int[]> hash_index = std::make_unique<int[]>(batch_ids_len);
386   MS_ERROR_IF_NULL_W_RET_VAL(hash_index, false);
387   if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
388     MS_LOG(ERROR) << "Process data memset failed.";
389     return false;
390   }
391   // Get hash swap in/out index and ids.
392   RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
393   DumpStatisticsInfo();
394   if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) {
395     MS_LOG(ERROR) << "Ps cache wait graph finish failed.";
396     return false;
397   }
398   for (const auto &item : hash_tables_) {
399     auto key = Worker::GetInstance().GetParamKey(item.first);
400     auto hash_info = item.second;
401     RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info));
402     RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info));
403     RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info));
404     RETURN_IF_FALSE(HashSwapHostToDevice(hash_info));
405   }
406   size_t dest_len = data_size;
407   // Replace the batch_ids by hash index for getNext-op getting hash index as input.
408   if (memcpy_s(data, dest_len, hash_index.get(), data_size) != EOK) {
409     MS_LOG(ERROR) << "Process data memcpy failed.";
410     return false;
411   }
412   RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
413   // Finish the data process and notify data prefetch.
414   RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_));
415   (void)gettimeofday(&end_time, nullptr);
416   uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
417   cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
418   const uint64_t milli_second_ratio = 1000;
419   MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_
420                 << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_
421                 << ", time cost:" << (cost / milli_second_ratio) << "ms).";
422   return true;
423 }
424 
CheckCacheHitOrOutRangeTask(const int * batch_ids,const size_t batch_ids_len,int * hash_index,bool * in_device,bool * out_range,size_t * hash_hit_count)425 bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
426                                                  bool *in_device, bool *out_range, size_t *hash_hit_count) {
427   MS_ERROR_IF_NULL(batch_ids);
428   MS_ERROR_IF_NULL(hash_index);
429   MS_ERROR_IF_NULL(in_device);
430   MS_ERROR_IF_NULL(hash_hit_count);
431   MS_ERROR_IF_NULL(embedding_device_cache_);
432   auto &device_hash_map = embedding_device_cache_->device_hash_map_;
433   MS_ERROR_IF_NULL(device_hash_map);
434   const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
435 
436   for (size_t i = 0; i < batch_ids_len; ++i) {
437     if (batch_ids[i] < emb_table_slice_bounds_.first) {
438       hash_index[i] = batch_ids[i] - vocab_cache_size_diff_;
439       out_range[i] = true;
440       continue;
441     }
442     if (batch_ids[i] >= emb_table_slice_bounds_.second) {
443       hash_index[i] = batch_ids[i] + cache_indices_bounds_.second;
444       out_range[i] = true;
445       continue;
446     }
447     auto iter = hash_id_to_index.find(batch_ids[i]);
448     if (iter != hash_id_to_index.end()) {
449       hash_index[i] = iter->second + cache_indices_bounds_.first;
450       if (device_hash_map->hash_step(iter->second) != data_step_) {
451         ++(*hash_hit_count);
452         device_hash_map->set_hash_step(iter->second, data_step_);
453       }
454       in_device[i] = true;
455     }
456   }
457   return true;
458 }
459 
CheckCacheHitOrOutRange(const int * batch_ids,const size_t batch_ids_len,int * hash_index,bool * in_device,bool * out_range)460 bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
461                                              bool *in_device, bool *out_range) {
462   MS_ERROR_IF_NULL(batch_ids);
463   MS_ERROR_IF_NULL(hash_index);
464   MS_ERROR_IF_NULL(in_device);
465   MS_ERROR_IF_NULL(out_range);
466 
467   size_t thread_num = batch_ids_len / kMaxIdsPerThread + 1;
468   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
469   std::thread threads[kMaxThreadNum];
470   size_t hash_hit_count[kMaxThreadNum] = {0};
471   size_t i = 0;
472   size_t task_offset = 0;
473 
474   for (; i < thread_num; ++i) {
475     if (task_offset >= batch_ids_len) {
476       break;
477     }
478     size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
479     threads[i] =
480       std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
481                   hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
482     task_offset += task_proc_lens;
483   }
484   if (task_offset != batch_ids_len) {
485     MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
486   }
487 
488   for (size_t j = 0; j < i; j++) {
489     threads[j].join();
490   }
491   for (size_t j = 0; j < i; j++) {
492     statistics_info_.hash_hit_count_ += hash_hit_count[j];
493   }
494   return true;
495 }
496 
ResetEmbeddingHashMap()497 bool PsCacheManager::ResetEmbeddingHashMap() {
498   MS_ERROR_IF_NULL(embedding_device_cache_);
499   const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
500   MS_ERROR_IF_NULL(device_hash_map);
501   MS_ERROR_IF_NULL(embedding_host_cache_);
502   const auto &host_hash_map = embedding_host_cache_->host_hash_map_;
503   MS_ERROR_IF_NULL(host_hash_map);
504   device_hash_map->Reset();
505   host_hash_map->Reset();
506   device_need_wait_graph_ = false;
507   host_need_wait_graph_ = false;
508   return true;
509 }
510 
ParseData(const int * batch_ids,const size_t batch_ids_len,int * hash_index)511 bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
512   MS_ERROR_IF_NULL(batch_ids);
513   MS_ERROR_IF_NULL(hash_index);
514   statistics_info_.batch_id_count_ = batch_ids_len;
515   std::unique_ptr<bool[]> in_device = std::make_unique<bool[]>(batch_ids_len);
516   std::unique_ptr<bool[]> out_range = std::make_unique<bool[]>(batch_ids_len);
517   if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
518     MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
519   }
520   if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
521     MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
522   }
523   RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
524   RETURN_IF_FALSE(ResetEmbeddingHashMap());
525   for (size_t i = 0; i < batch_ids_len; i++) {
526     if (in_device[i] || out_range[i]) {
527       continue;
528     }
529     bool need_swap_host_to_device = true;
530     bool need_swap_device_to_host = true;
531     int index = INVALID_INDEX_VALUE;
532     RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
533     hash_index[i] = index + cache_indices_bounds_.first;
534     if (need_swap_host_to_device) {
535       RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
536     }
537     if (need_swap_device_to_host) {
538       RETURN_IF_FALSE(ParseHostDataDeviceToHost());
539     }
540   }
541   return true;
542 }
543 
WaitGraphRun()544 bool PsCacheManager::WaitGraphRun() {
545   MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
546   std::unique_lock<std::mutex> locker(data_mutex_);
547   const int64_t longest_time_to_wait = 120;
548   if (!data_prase_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
549                             [this] { return graph_step_ > graph_running_step_; })) {
550     MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
551                   << ", graph running step:" << graph_running_step_ << ").";
552     return false;
553   }
554   set_current_graph_step();
555   return true;
556 }
557 
ParseDeviceData(size_t id,bool * need_swap_device_to_host,bool * need_swap_host_to_device,int * hash_index)558 bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device,
559                                      int *hash_index) {
560   MS_ERROR_IF_NULL(need_swap_device_to_host);
561   MS_ERROR_IF_NULL(need_swap_host_to_device);
562   MS_ERROR_IF_NULL(hash_index);
563   MS_ERROR_IF_NULL(embedding_device_cache_);
564   auto &device_hash_map = embedding_device_cache_->device_hash_map_;
565   MS_ERROR_IF_NULL(device_hash_map);
566 
567   int index = INVALID_INDEX_VALUE;
568   const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
569   const auto &iter = hash_id_to_index.find(id);
570   if (iter != hash_id_to_index.end()) {
571     *need_swap_device_to_host = false;
572     *need_swap_host_to_device = false;
573     index = iter->second;
574     if (device_hash_map->hash_step(index) != data_step_) {
575       statistics_info_.hash_hit_count_++;
576       device_hash_map->set_hash_step(index, data_step_);
577     }
578   } else {
579     int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
580     int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
581     int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
582     int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
583     MS_ERROR_IF_NULL(host_to_device_index);
584     MS_ERROR_IF_NULL(host_to_device_ids);
585     auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
586     while (true) {
587       index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
588                                          &(statistics_info_.device_to_host_size_), &device_need_wait_graph_);
589       if (index == INVALID_INDEX_VALUE) {
590         if (!WaitGraphRun()) {
591           return false;
592         }
593         continue;
594       }
595       host_to_device_index[statistics_info_.host_to_device_size_] = index;
596       host_to_device_ids[statistics_info_.host_to_device_size_] = id;
597       statistics_info_.host_to_device_size_++;
598       *need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size;
599       break;
600     }
601   }
602   *hash_index = index;
603   return true;
604 }
605 
ParseHostDataHostToDevice(size_t id)606 bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
607   MS_ERROR_IF_NULL(embedding_host_cache_);
608   int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
609   MS_ERROR_IF_NULL(host_to_device_index);
610   auto &host_hash_map = embedding_host_cache_->host_hash_map_;
611   MS_ERROR_IF_NULL(host_hash_map);
612 
613   const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
614   const auto &iter = hash_id_to_index.find(id);
615   if (iter != hash_id_to_index.end()) {
616     auto index = iter->second;
617     if (host_hash_map->hash_step(index) != data_step_) {
618       host_hash_map->set_hash_step(index, data_step_);
619     }
620     host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
621   } else {
622     int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
623     int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
624     int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
625     int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
626     MS_ERROR_IF_NULL(server_to_host_index);
627     MS_ERROR_IF_NULL(server_to_host_ids);
628     while (true) {
629       auto index =
630         host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_,
631                                  &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
632       if (index == INVALID_INDEX_VALUE) {
633         RETURN_IF_FALSE(WaitGraphRun());
634         continue;
635       }
636       host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
637       server_to_host_index[statistics_info_.server_to_host_size_] = index;
638       server_to_host_ids[statistics_info_.server_to_host_size_++] = id;
639       break;
640     }
641   }
642   return true;
643 }
644 
ParseHostDataDeviceToHost()645 bool PsCacheManager::ParseHostDataDeviceToHost() {
646   MS_ERROR_IF_NULL(embedding_device_cache_);
647   MS_ERROR_IF_NULL(embedding_host_cache_);
648   int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
649   int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
650   MS_ERROR_IF_NULL(device_to_host_ids);
651   MS_ERROR_IF_NULL(device_to_host_index);
652 
653   auto &host_hash_map = embedding_host_cache_->host_hash_map_;
654   MS_ERROR_IF_NULL(host_hash_map);
655   int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
656   const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
657   const auto &iter = hash_id_to_index.find(swap_device_to_host_id);
658   if (iter != hash_id_to_index.end()) {
659     auto index = iter->second;
660     if (host_hash_map->hash_step(index) != data_step_) {
661       host_hash_map->set_hash_step(index, data_step_);
662     }
663     device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
664   } else {
665     int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
666     int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
667     while (true) {
668       auto index =
669         host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_,
670                                  graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
671       if (index == INVALID_INDEX_VALUE) {
672         RETURN_IF_FALSE(WaitGraphRun());
673         continue;
674       }
675       device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
676       break;
677     }
678   }
679   return true;
680 }
681 
LookUpTableTask(size_t indices_lens,size_t outer_dim_size,size_t first_dim_size,const float * input_addr,const int * indices_addr,float * output_addr)682 void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
683                                      const float *input_addr, const int *indices_addr, float *output_addr) {
684   MS_ERROR_IF_NULL_WO_RET_VAL(input_addr);
685   MS_ERROR_IF_NULL_WO_RET_VAL(indices_addr);
686   MS_ERROR_IF_NULL_WO_RET_VAL(output_addr);
687   auto type_size = sizeof(float);
688   size_t lens = outer_dim_size * type_size;
689   for (size_t i = 0; i < indices_lens; ++i) {
690     int index = indices_addr[i];
691     if (index >= 0 && index < SizeToInt(first_dim_size)) {
692       size_t pos = index * outer_dim_size;
693       auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
694       if (ret != EOK) {
695         MS_LOG(ERROR) << "LookUpTable task memcpy failed.";
696         running_ = false;
697         return;
698       }
699     } else {
700       auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
701       if (ret != EOK) {
702         MS_LOG(ERROR) << "LookUpTable task memset failed.";
703         running_ = false;
704         return;
705       }
706     }
707     output_addr += outer_dim_size;
708   }
709 }
710 
LookUpHostHashTable(size_t embedding_size,size_t indices_lens,const float * hash_table_addr,const int * indices_addr,float * output_addr)711 bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
712                                          const int *indices_addr, float *output_addr) {
713   MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
714   MS_ERROR_IF_NULL_W_RET_VAL(indices_addr, false);
715   MS_ERROR_IF_NULL_W_RET_VAL(output_addr, false);
716   size_t first_dim_size = host_vocab_cache_size_;
717   size_t outer_dim_size = embedding_size;
718 
719   size_t thread_num = indices_lens / kMaxIdsPerThread + 1;
720   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
721   std::thread threads[kMaxThreadNum];
722   size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
723   size_t i = 0;
724   size_t task_offset = 0;
725   MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
726   for (; i < thread_num; i++) {
727     if (task_offset >= indices_lens) {
728       break;
729     }
730     MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
731     threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size,
732                              hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size);
733     task_offset += task_proc_lens;
734     if (task_offset + task_proc_lens > indices_lens) {
735       task_proc_lens = indices_lens - task_offset;
736     }
737   }
738   for (size_t j = 0; j < i; j++) {
739     threads[j].join();
740   }
741   return running_;
742 }
743 
InsertHostHashTable(size_t embedding_size,size_t insert_indices_size,const int * insert_indices,const float * insert_data,float * hash_table_addr)744 bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
745                                          const float *insert_data, float *hash_table_addr) {
746   MS_ERROR_IF_NULL_W_RET_VAL(insert_indices, false);
747   MS_ERROR_IF_NULL_W_RET_VAL(insert_data, false);
748   MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
749   size_t first_dim_size = host_vocab_cache_size_;
750   size_t thread_num = insert_indices_size / kMaxIdsPerThread + 1;
751   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
752   std::thread threads[kMaxThreadNum];
753   size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
754   size_t i = 0;
755   size_t task_offset = 0;
756 
757   auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
758                                        const int *insert_indices, const float *insert_data, float *hash_table_addr) {
759     auto type_size = sizeof(float);
760     size_t copy_len = outer_dim_size * type_size;
761     size_t dest_len = copy_len;
762     for (size_t i = 0; i < insert_indices_size; ++i) {
763       int index = insert_indices[i];
764       if (index >= 0 && index < SizeToInt(first_dim_size)) {
765         auto ret =
766           memcpy_s(hash_table_addr + index * outer_dim_size, dest_len, insert_data + i * outer_dim_size, copy_len);
767         if (ret != EOK) {
768           MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
769           running_ = false;
770           return;
771         }
772       }
773     }
774   };
775 
776   for (; i < thread_num; i++) {
777     if (task_offset >= insert_indices_size) {
778       break;
779     }
780     MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
781     threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size,
782                              insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr);
783     task_offset += task_proc_lens;
784     if (task_offset + task_proc_lens > insert_indices_size) {
785       task_proc_lens = insert_indices_size - task_offset;
786     }
787   }
788 
789   for (size_t j = 0; j < i; j++) {
790     threads[j].join();
791   }
792   return running_;
793 }
794 
HashSwapHostToDevice(const HashTableInfo & hash_info)795 bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
796   MS_ERROR_IF_NULL(embedding_device_cache_);
797   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
798   MS_ERROR_IF_NULL(embedding_host_cache_);
799   auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
800   auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get();
801   auto swap_indices_size = statistics_info_.host_to_device_size_;
802   if (swap_indices_size == 0) {
803     return true;
804   }
805   auto embedding_size = hash_info.embedding_size;
806   MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
807   auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
808   auto cache_vocab_size = hash_info.cache_vocab_size;
809   MS_ERROR_IF_NULL_W_RET_VAL(hash_info.host_address, false);
810   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
811   auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
812   RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
813                                       host_cache_host_to_device_index, swap_out_data.get()));
814   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
815     embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
816     swap_indices_size * embedding_size * sizeof(float)));
817   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
818                                                                        device_cache_host_to_device_index,
819                                                                        swap_indices_size * sizeof(int)));
820   RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
821     hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
822     cache_vocab_size, embedding_size, swap_indices_size));
823   RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
824   return true;
825 }
826 
HashSwapDeviceToHost(const HashTableInfo & hash_info)827 bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
828   MS_ERROR_IF_NULL(embedding_device_cache_);
829   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
830   MS_ERROR_IF_NULL(embedding_host_cache_);
831   auto swap_indices_size = statistics_info_.device_to_host_size_;
832   auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
833   auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get();
834   if (swap_indices_size == 0) {
835     return true;
836   }
837   auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
838   auto cache_vocab_size = hash_info.cache_vocab_size;
839   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
840   auto embedding_size = hash_info.embedding_size;
841   auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
842   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
843                                                                        device_cache_device_to_host_index,
844                                                                        swap_indices_size * sizeof(int)));
845   RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
846     hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
847     cache_vocab_size, embedding_size, swap_indices_size));
848   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
849     swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
850     swap_indices_size * embedding_size * sizeof(float)));
851   RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
852   RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
853                                       swap_out_data.get(), host_hash_table_addr));
854   return true;
855 }
856 
HashSwapHostToServer(size_t key,const HashTableInfo & hash_info)857 bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
858   MS_ERROR_IF_NULL(embedding_host_cache_);
859   auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
860   MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_ids, false);
861   auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
862   MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_index, false);
863   auto swap_indices_size = statistics_info_.host_to_server_size_;
864   if (swap_indices_size == 0) {
865     return true;
866   }
867   std::vector<int> lookup_ids(swap_indices_size, 0);
868   std::vector<float> swap_out_data;
869   auto embedding_size = hash_info.embedding_size;
870   swap_out_data.resize(swap_indices_size * embedding_size);
871   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
872   RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
873                                       swap_out_data.data()));
874 
875   size_t copy_len = swap_indices_size * sizeof(int);
876   size_t dest_len = copy_len;
877   auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids, copy_len);
878   if (ret != EOK) {
879     MS_LOG(ERROR) << "Lookup id memcpy failed.";
880     return false;
881   }
882   Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
883   return true;
884 }
885 
HashSwapServerToHost(size_t key,const HashTableInfo & hash_info)886 bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
887   MS_ERROR_IF_NULL(embedding_host_cache_);
888   auto swap_indices_size = statistics_info_.server_to_host_size_;
889   auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
890   MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_ids, false);
891   auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
892   MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_index, false);
893   if (swap_indices_size == 0) {
894     return true;
895   }
896   auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
897   MS_ERROR_IF_NULL_W_RET_VAL(host_hash_table_addr, false);
898   auto embedding_size = hash_info.embedding_size;
899   std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
900   std::vector<int> lookup_ids(swap_indices_size, 0);
901   size_t copy_len = swap_indices_size * sizeof(int);
902   size_t dest_len = copy_len;
903   auto ret = memcpy_s(lookup_ids.data(), dest_len, server_to_host_ids, copy_len);
904   if (ret != EOK) {
905     MS_LOG(ERROR) << "Lookup id memcpy failed.";
906     return false;
907   }
908   Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
909   RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
910                                       lookup_result.data(), host_hash_table_addr));
911   return true;
912 }
913 
HashSwapDeviceOut(int * swap_out_index,std::vector<float> * swap_out_data,const HashTableInfo & hash_info)914 bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data,
915                                        const HashTableInfo &hash_info) {
916   MS_ERROR_IF_NULL(swap_out_index);
917   MS_ERROR_IF_NULL(swap_out_data);
918   MS_ERROR_IF_NULL(embedding_device_cache_);
919   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
920   auto swap_out_index_size = statistics_info_.device_to_host_size_;
921   if (swap_out_index_size == 0) {
922     return true;
923   }
924 
925   MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
926   auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
927   auto cache_vocab_size = hash_info.cache_vocab_size;
928   auto embedding_size = hash_info.embedding_size;
929   swap_out_data->resize(swap_out_index_size * embedding_size);
930   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
931     embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
932   RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
933     hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
934     cache_vocab_size, embedding_size, swap_out_index_size));
935   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
936     swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
937     swap_out_index_size * embedding_size * sizeof(float)));
938   RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent());
939   return true;
940 }
941 
HashSwapDeviceIn(const int * swap_in_ids,const int * swap_in_index,const HashTableInfo & hash_info,size_t key)942 bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info,
943                                       size_t key) {
944   MS_ERROR_IF_NULL(swap_in_ids);
945   MS_ERROR_IF_NULL(swap_in_index);
946   MS_ERROR_IF_NULL(embedding_device_cache_);
947   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
948   auto swap_in_ids_size = statistics_info_.host_to_device_size_;
949   if (swap_in_ids_size == 0) {
950     return true;
951   }
952 
953   MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
954   auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
955   auto cache_vocab_size = hash_info.cache_vocab_size;
956   auto embedding_size = hash_info.embedding_size;
957   // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
958   std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
959   std::vector<int> lookup_ids(swap_in_ids_size, 0);
960   size_t copy_len = swap_in_ids_size * sizeof(int);
961   size_t dest_len = copy_len;
962   auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_in_ids, copy_len);
963   if (ret != EOK) {
964     MS_LOG(ERROR) << "Lookup id memcpy failed.";
965     return false;
966   }
967   Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
968   // Hash swap-in in device.
969   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
970     embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
971     swap_in_ids_size * embedding_size * sizeof(float)));
972   RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
973                                                                        swap_in_index, swap_in_ids_size * sizeof(int)));
974   RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
975     hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
976     cache_vocab_size, embedding_size, swap_in_ids_size));
977   return true;
978 }
979 
UpdataEmbeddingTable(const std::vector<float> & swap_out_data,int * const swap_out_ids,size_t key)980 bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids,
981                                           size_t key) {
982   MS_ERROR_IF_NULL(embedding_device_cache_);
983   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
984   MS_ERROR_IF_NULL(swap_out_ids);
985   auto swap_out_ids_size = statistics_info_.device_to_host_size_;
986   if (swap_out_ids_size == 0) {
987     return true;
988   }
989   std::vector<int> lookup_ids(swap_out_ids_size, 0);
990   size_t copy_len = swap_out_ids_size * sizeof(int);
991   size_t dest_len = copy_len;
992   auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_out_ids, copy_len);
993   if (ret != EOK) {
994     MS_LOG(ERROR) << "Lookup id memcpy failed.";
995     return false;
996   }
997   // Need synchronize event to ensure that the swap-out in device is completed.
998   RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent());
999   Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1000   return true;
1001 }
1002 
SyncEmbeddingTable()1003 void PsCacheManager::SyncEmbeddingTable() {
1004   if (finish_embedding_table_sync_) {
1005     return;
1006   }
1007   if (!initialized_ps_cache_) {
1008     return;
1009   }
1010   if (!SyncHostEmbeddingTable()) {
1011     MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
1012   }
1013   if (!SyncDeviceEmbeddingTable()) {
1014     MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
1015   }
1016   finish_embedding_table_sync_ = true;
1017 }
1018 
SyncHostEmbeddingTable()1019 bool PsCacheManager::SyncHostEmbeddingTable() {
1020   MS_ERROR_IF_NULL(embedding_host_cache_);
1021   MS_ERROR_IF_NULL(embedding_host_cache_->host_hash_map_);
1022   const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
1023   size_t swap_indices_lens = hash_id_to_index.size();
1024   if (swap_indices_lens == 0) {
1025     return true;
1026   }
1027   std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1028   MS_ERROR_IF_NULL(host_to_server_ids_ptr);
1029   std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1030   MS_ERROR_IF_NULL(host_to_server_indices_ptr);
1031   size_t idx = 0;
1032   for (const auto &item : hash_id_to_index) {
1033     host_to_server_ids_ptr[idx] = item.first;
1034     host_to_server_indices_ptr[idx++] = item.second;
1035   }
1036   for (const auto &item : hash_tables_) {
1037     const auto &hash_info = item.second;
1038     if (hash_info.param_init_info_.param_type_ != kWeight) {
1039       continue;
1040     }
1041     auto key = Worker::GetInstance().GetParamKey(item.first);
1042     std::vector<int> lookup_ids(swap_indices_lens, 0);
1043     std::vector<float> swap_out_data;
1044     auto embedding_size = hash_info.embedding_size;
1045     swap_out_data.resize(swap_indices_lens * embedding_size);
1046     auto host_hash_table_addr = hash_info.host_address.get();
1047     MS_ERROR_IF_NULL(host_hash_table_addr);
1048     RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr,
1049                                         host_to_server_indices_ptr.get(), swap_out_data.data()));
1050 
1051     size_t copy_len = swap_indices_lens * sizeof(int);
1052     size_t dest_len = copy_len;
1053     auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids_ptr.get(), copy_len);
1054     if (ret != EOK) {
1055       MS_LOG(ERROR) << "Lookup id memcpy failed.";
1056       return false;
1057     }
1058     Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1059   }
1060   return true;
1061 }
1062 
SyncDeviceEmbeddingTable()1063 bool PsCacheManager::SyncDeviceEmbeddingTable() {
1064   MS_ERROR_IF_NULL(embedding_device_cache_);
1065   MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
1066   const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
1067   MS_ERROR_IF_NULL(device_hash_map);
1068   const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
1069   size_t swap_indices_lens = hash_id_to_index.size();
1070   if (swap_indices_lens == 0) {
1071     return true;
1072   }
1073   std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1074   MS_ERROR_IF_NULL(device_to_server_ids_ptr);
1075   std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1076   MS_ERROR_IF_NULL(device_to_server_indices_ptr);
1077   size_t idx = 0;
1078   for (const auto &item : hash_id_to_index) {
1079     device_to_server_ids_ptr[idx] = item.first;
1080     device_to_server_indices_ptr[idx++] = item.second;
1081   }
1082   for (const auto &item : hash_tables_) {
1083     const auto &hash_info = item.second;
1084     if (hash_info.param_init_info_.param_type_ != kWeight) {
1085       continue;
1086     }
1087     auto key = Worker::GetInstance().GetParamKey(item.first);
1088     std::vector<int> lookup_ids(swap_indices_lens, 0);
1089     std::vector<float> swap_out_data;
1090     auto embedding_size = hash_info.embedding_size;
1091     swap_out_data.resize(swap_indices_lens * embedding_size);
1092     std::unique_ptr<float[]> device_hash_table_addr_tmp =
1093       std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size);
1094     MS_ERROR_IF_NULL(device_hash_table_addr_tmp);
1095 
1096     auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
1097     MS_ERROR_IF_NULL(hash_table_addr);
1098     auto hash_table_size = hash_info.device_address.size;
1099     RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(device_hash_table_addr_tmp.get(),
1100                                                                          hash_table_addr, hash_table_size));
1101     RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
1102     RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
1103                                         device_to_server_indices_ptr.get(), swap_out_data.data()));
1104 
1105     size_t copy_len = swap_indices_lens * sizeof(int);
1106     size_t dest_len = copy_len;
1107     auto ret = memcpy_s(lookup_ids.data(), dest_len, device_to_server_ids_ptr.get(), copy_len);
1108     if (ret != EOK) {
1109       MS_LOG(ERROR) << "Lookup id memcpy failed.";
1110       return false;
1111     }
1112     Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1113   }
1114   return true;
1115 }
1116 
DumpHashTables(bool dump_device_tables) const1117 void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
1118   MS_EXCEPTION_IF_NULL(embedding_device_cache_);
1119   MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
1120   for (const auto &item : hash_tables_) {
1121     const auto &param_name = item.first;
1122     size_t cache_vocab_size = item.second.cache_vocab_size;
1123     size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
1124     size_t embedding_size = item.second.embedding_size;
1125     size_t vocab_size = item.second.vocab_size;
1126     MS_LOG(INFO) << "Hash table info:"
1127                  << " embedding table name:" << param_name << ", vocab size:" << vocab_size
1128                  << ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size
1129                  << ", host cache size:" << host_cache_vocab_size
1130                  << ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
1131                  << ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
1132     if (dump_device_tables) {
1133       std::unique_ptr<float[]> output = std::make_unique<float[]>(item.second.device_address.size / sizeof(float));
1134       embedding_device_cache_->cache_->CopyDeviceMemToHost(output.get(), item.second.device_address.addr,
1135                                                            item.second.device_address.size);
1136       embedding_device_cache_->cache_->SynchronizeStream();
1137       for (size_t i = 0; i < cache_vocab_size; i++) {
1138         for (size_t j = 0; j < embedding_size; j++) {
1139           std::cout << output[i * embedding_size + j] << " ";
1140         }
1141         std::cout << std::endl;
1142       }
1143       std::cout << std::endl;
1144     }
1145   }
1146 }
1147 
DumpStatisticsInfo(size_t each_print_step)1148 void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) {
1149   // Default each 1000 step prints ps cache hit rate.
1150   const size_t kFloatToPercentSign = 100;
1151   if (data_step_ % each_print_step == 0) {
1152     statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
1153     auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) /
1154                        statistics_info_.batch_id_count_;
1155     auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
1156     auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) /
1157                          statistics_info_.batch_id_unique_count_;
1158     MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_
1159                  << ", unique id num:" << statistics_info_.batch_id_unique_count_
1160                  << ", host swap to device num:" << statistics_info_.host_to_device_size_
1161                  << ", device swap to host num:" << statistics_info_.device_to_host_size_
1162                  << ", host swap to server num:" << statistics_info_.host_to_server_size_
1163                  << ", server swap to host num:" << statistics_info_.server_to_host_size_
1164                  << ", data repeat rate:" << (repeat_rate * kFloatToPercentSign)
1165                  << "%, device cache hit rate:" << (device_hit_rate * kFloatToPercentSign)
1166                  << "%, host cache hit rate:" << (host_hit_rate * kFloatToPercentSign) << ").";
1167   }
1168 }
1169 }  // namespace ps
1170 }  // namespace mindspore
1171