1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include "ps/ps_cache/ps_cache_manager.h"
19 #include "utils/log_adapter.h"
20 #include "utils/ms_utils.h"
21
22 using mindspore::kernel::Address;
23 namespace mindspore {
24 namespace ps {
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size)25 void PsCacheManager::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size,
26 size_t vocab_size) {
27 if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
28 MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
29 }
30 hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
31 hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
32 hash_tables_[param_name].embedding_size = embedding_size;
33 hash_tables_[param_name].vocab_size = vocab_size;
34
35 if (vocab_size_ == 0) {
36 vocab_size_ = vocab_size;
37 }
38 if (vocab_cache_size_ == 0) {
39 vocab_cache_size_ = cache_vocab_size;
40 }
41 if (host_vocab_cache_size_ == 0) {
42 host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
43 }
44 }
45
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name,size_t cache_vocab_size,size_t embedding_size)46 void PsCacheManager::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
47 size_t cache_vocab_size, size_t embedding_size) {
48 if (cache_vocab_size == 0 || embedding_size == 0) {
49 MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
50 }
51 if (new_param_name.empty() || cur_param_name.empty()) {
52 MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
53 }
54 if (new_param_name == cur_param_name) {
55 return;
56 }
57 auto iter = hash_tables_.find(cur_param_name);
58 if (iter != hash_tables_.end()) {
59 hash_tables_.emplace(new_param_name, iter->second);
60 hash_tables_.erase(iter);
61 } else {
62 hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size;
63 hash_tables_[new_param_name].embedding_size = embedding_size;
64 }
65 }
66
InsertWeightInitInfo(const std::string & param_name,size_t global_seed,size_t op_seed)67 void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) {
68 auto iter = hash_tables_.find(param_name);
69 if (iter == hash_tables_.end()) {
70 MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
71 }
72 auto &hash_table_info = iter->second;
73 if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
74 return;
75 }
76 MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed
77 << ", op seed:" << op_seed;
78 hash_table_info.param_init_info_.param_name_ = param_name;
79 hash_table_info.param_init_info_.param_type_ = kWeight;
80 hash_table_info.param_init_info_.global_seed_ = global_seed;
81 hash_table_info.param_init_info_.op_seed_ = op_seed;
82 if (CheckFinishInsertInitInfo()) {
83 finish_insert_init_info_ = true;
84 insert_init_info_.notify_one();
85 }
86 }
87
InsertAccumuInitInfo(const std::string & param_name,float init_val)88 void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) {
89 auto iter = hash_tables_.find(param_name);
90 if (iter == hash_tables_.end()) {
91 MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
92 }
93 auto &hash_table_info = iter->second;
94 if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
95 return;
96 }
97 MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
98 hash_table_info.param_init_info_.param_name_ = param_name;
99 hash_table_info.param_init_info_.param_type_ = kAccumulation;
100 hash_table_info.param_init_info_.init_val_ = init_val;
101 if (CheckFinishInsertInitInfo()) {
102 finish_insert_init_info_ = true;
103 insert_init_info_.notify_one();
104 }
105 }
106
CheckFinishInsertInitInfo() const107 bool PsCacheManager::CheckFinishInsertInitInfo() const {
108 for (const auto &item : hash_tables_) {
109 const auto &hash_table_info = item.second;
110 const auto ¶m_init_info = hash_table_info.param_init_info_;
111 if (param_init_info.param_type_ == kUnKnown) {
112 return false;
113 }
114 }
115 MS_LOG(INFO) << "Finish inserting embedding table init info.";
116 return true;
117 }
118
CloneHashTable(const std::string & dest_param_name,const std::string & src_param_name)119 void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
120 if (dest_param_name == src_param_name) {
121 MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
122 return;
123 }
124 auto iter = hash_tables_.find(src_param_name);
125 if (iter == hash_tables_.end()) {
126 MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
127 }
128 hash_tables_.emplace(dest_param_name, iter->second);
129 }
130
QueryHashTableAddr(const std::string & param_name) const131 const Address &PsCacheManager::QueryHashTableAddr(const std::string ¶m_name) const {
132 auto iter = hash_tables_.find(param_name);
133 if (iter == hash_tables_.end()) {
134 MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
135 }
136 return iter->second.device_address;
137 }
138
QueryHashTableSize(const std::string & param_name) const139 const size_t &PsCacheManager::QueryHashTableSize(const std::string ¶m_name) const {
140 auto iter = hash_tables_.find(param_name);
141 if (iter == hash_tables_.end()) {
142 MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
143 }
144 return iter->second.cache_vocab_size;
145 }
146
Initialize()147 void PsCacheManager::Initialize() {
148 MS_LOG(INFO) << "PS cache initialize.";
149 if (!Worker::GetInstance().running()) {
150 Worker::GetInstance().Run();
151 }
152 embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
153 MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
154 embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
155 MS_ERROR_IF_NULL_WO_RET_VAL(embedding_host_cache_);
156 AddEmbeddingTable();
157 AllocMemForHashTable();
158 SetLocalIdRank();
159 DumpHashTables();
160 initialized_ps_cache_ = true;
161 }
162
AddEmbeddingTable() const163 void PsCacheManager::AddEmbeddingTable() const {
164 for (const auto &item : hash_tables_) {
165 const auto ¶m_name = item.first;
166 size_t key = Worker::GetInstance().SetParamKey(param_name);
167 size_t row_count = item.second.vocab_size;
168 // if worker role
169 Worker::GetInstance().AddEmbeddingTable(key, row_count);
170 }
171 }
172
InitParameterServer()173 void PsCacheManager::InitParameterServer() {
174 MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_;
175 std::unique_lock<std::mutex> locker(data_mutex_);
176 insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; });
177 if (!running_) {
178 return;
179 }
180 for (const auto &item : hash_tables_) {
181 const auto ¶m_name = item.first;
182 size_t key = Worker::GetInstance().SetParamKey(param_name);
183 const auto &hash_table_info = item.second;
184 const auto ¶m_init_info = hash_table_info.param_init_info_;
185
186 std::vector<size_t> input_shape = {item.second.vocab_size, item.second.embedding_size};
187 std::vector<size_t> indices_shape = {1, 1};
188 std::vector<size_t> output_shape = {1, 1, 1};
189 ParamInitInfoMessage info;
190 info.set_param_name(param_name);
191 info.set_param_type(param_init_info.param_type_);
192 info.set_init_val(param_init_info.init_val_);
193 info.set_global_seed(param_init_info.global_seed_);
194 info.set_op_seed(param_init_info.op_seed_);
195 // if worker role
196 Worker::GetInstance().InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
197 }
198
199 finish_init_parameter_server_ = true;
200 data_prase_.notify_one();
201 MS_LOG(INFO) << "PS embedding cache table init end.";
202 }
203
InitDataChannel()204 void PsCacheManager::InitDataChannel() {
205 MS_LOG(INFO) << "PS embedding cache data channel init begin.";
206 auto channel = channel_name();
207 if (channel.empty()) {
208 std::unique_lock<std::mutex> locker(data_mutex_);
209 data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
210 if (!running_) {
211 return;
212 }
213 }
214 MS_LOG(INFO) << "PS embedding cache data channel init end.";
215 }
216
AllocMemForHashTable()217 void PsCacheManager::AllocMemForHashTable() {
218 MS_EXCEPTION_IF_NULL(embedding_device_cache_);
219 MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
220 size_t max_embedding_size = 0;
221 for (auto &item : hash_tables_) {
222 size_t embedding_size = item.second.embedding_size;
223 auto &device_address = item.second.device_address;
224 device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
225 auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
226 MS_EXCEPTION_IF_NULL(addr);
227 device_address.addr = addr;
228
229 auto &host_address = item.second.host_address;
230 std::unique_ptr<float[]> host_hash_table_addr = std::make_unique<float[]>(host_vocab_cache_size_ * embedding_size);
231 MS_EXCEPTION_IF_NULL(host_hash_table_addr);
232 host_address = std::move(host_hash_table_addr);
233 MS_EXCEPTION_IF_NULL(host_address);
234
235 max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
236 }
237 embedding_device_cache_->hash_swap_index_addr_ =
238 reinterpret_cast<int *>(embedding_device_cache_->cache_->MallocMemory(batch_elements_ * sizeof(int)));
239 MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
240 embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
241 embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
242 MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
243 }
244
SetLocalIdRank()245 void PsCacheManager::SetLocalIdRank() {
246 auto worker_num = PSContext::instance()->initial_worker_num();
247 if (worker_num > 0) {
248 auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
249 vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
250 emb_table_slice_bounds_.first = local_shard_size * rank_id_;
251 emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
252 cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
253 cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
254 MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
255 << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
256 << ", cache indices begin: " << cache_indices_bounds_.first
257 << ", cache indices end: " << cache_indices_bounds_.second
258 << ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
259 }
260 }
261
cache_indices_lower_bound() const262 int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
263
channel_name()264 std::string PsCacheManager::channel_name() {
265 std::lock_guard<std::mutex> locker(channel_mutex_);
266 return channel_name_;
267 }
268
set_channel_name(const std::string channel_name)269 void PsCacheManager::set_channel_name(const std::string channel_name) {
270 if (channel_name_ == channel_name) {
271 return;
272 }
273 std::lock_guard<std::mutex> locker(channel_mutex_);
274 channel_name_ = channel_name;
275 }
276
IncreaseStep()277 bool PsCacheManager::IncreaseStep() {
278 if (data_step_ >= UINT64_MAX) {
279 MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
280 return false;
281 }
282 data_step_++;
283 set_current_graph_step();
284 if (graph_running_step_ > data_step_) {
285 MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
286 << ").";
287 return false;
288 }
289 return true;
290 }
291
IncreaseGraphStep(const std::string & channel_name)292 void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
293 if (!running_) {
294 MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
295 }
296 if (graph_step_ >= UINT64_MAX) {
297 MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
298 }
299 if (graph_step_ == 0) {
300 MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
301 std::unique_lock<std::mutex> locker(data_mutex_);
302 data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); });
303 if (!running_) {
304 MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
305 }
306 MS_LOG(INFO) << "Graph running waiting embedding table init end.";
307 }
308 graph_step_++;
309 set_channel_name(channel_name);
310 if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
311 MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
312 }
313 data_prase_.notify_one();
314 }
315
DoProcessData(uint32_t device_id,const void * context)316 void PsCacheManager::DoProcessData(uint32_t device_id, const void *context) {
317 // PS embeddingLookup cache check.
318 if (!initialized_ps_cache_) {
319 MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training "
320 "mode, current dataset mode is not sink_mode.";
321 }
322 process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
323 }
324
ProcessDataTask(uint32_t device_id,const void * context)325 void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) {
326 MS_LOG(INFO) << "PS embedding cache process data task begin.";
327 running_ = true;
328 MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
329 MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_->cache_);
330 embedding_device_cache_->cache_->InitDevice(device_id, context);
331
332 // MallocConstantMemory need stream on device Ascend, should be called after InitDevice.
333 if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
334 MS_LOG(ERROR) << "MallocConstantMemory failed.";
335 running_ = false;
336 return;
337 }
338
339 InitParameterServer();
340 InitDataChannel();
341 while (running_) {
342 if (!ProcessData()) {
343 running_ = false;
344 }
345 }
346 MS_LOG(INFO) << "PS embedding cache process data task end.";
347 }
348
Finalize()349 void PsCacheManager::Finalize() {
350 if (running_) {
351 SyncEmbeddingTable();
352 }
353 running_ = false;
354 PsDataPrefetch::GetInstance().NotifyFinalize();
355 insert_init_info_.notify_all();
356 data_prase_.notify_all();
357 if (process_data_thread_.joinable()) {
358 process_data_thread_.join();
359 }
360 }
361
ProcessData()362 bool PsCacheManager::ProcessData() {
363 struct timeval start_time, end_time;
364 const uint64_t kUSecondInSecond = 1000000;
365 (void)gettimeofday(&start_time, nullptr);
366 void *data = nullptr;
367 if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
368 return false;
369 }
370 if (data == nullptr) {
371 MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
372 std::unique_lock<std::mutex> locker(data_mutex_);
373 const int64_t longest_time_to_wait = 100;
374 (void)data_prase_.wait_for(locker, std::chrono::milliseconds(longest_time_to_wait));
375 return true;
376 }
377 RETURN_IF_FALSE(IncreaseStep());
378 auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_);
379 if (data_size == 0) {
380 MS_LOG(ERROR) << "The data_size can not be zero.";
381 return false;
382 }
383 auto batch_ids = reinterpret_cast<int *>(data);
384 auto batch_ids_len = data_size / sizeof(int);
385 std::unique_ptr<int[]> hash_index = std::make_unique<int[]>(batch_ids_len);
386 MS_ERROR_IF_NULL_W_RET_VAL(hash_index, false);
387 if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
388 MS_LOG(ERROR) << "Process data memset failed.";
389 return false;
390 }
391 // Get hash swap in/out index and ids.
392 RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
393 DumpStatisticsInfo();
394 if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) {
395 MS_LOG(ERROR) << "Ps cache wait graph finish failed.";
396 return false;
397 }
398 for (const auto &item : hash_tables_) {
399 auto key = Worker::GetInstance().GetParamKey(item.first);
400 auto hash_info = item.second;
401 RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info));
402 RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info));
403 RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info));
404 RETURN_IF_FALSE(HashSwapHostToDevice(hash_info));
405 }
406 size_t dest_len = data_size;
407 // Replace the batch_ids by hash index for getNext-op getting hash index as input.
408 if (memcpy_s(data, dest_len, hash_index.get(), data_size) != EOK) {
409 MS_LOG(ERROR) << "Process data memcpy failed.";
410 return false;
411 }
412 RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
413 // Finish the data process and notify data prefetch.
414 RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_));
415 (void)gettimeofday(&end_time, nullptr);
416 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
417 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
418 const uint64_t milli_second_ratio = 1000;
419 MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_
420 << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_
421 << ", time cost:" << (cost / milli_second_ratio) << "ms).";
422 return true;
423 }
424
CheckCacheHitOrOutRangeTask(const int * batch_ids,const size_t batch_ids_len,int * hash_index,bool * in_device,bool * out_range,size_t * hash_hit_count)425 bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
426 bool *in_device, bool *out_range, size_t *hash_hit_count) {
427 MS_ERROR_IF_NULL(batch_ids);
428 MS_ERROR_IF_NULL(hash_index);
429 MS_ERROR_IF_NULL(in_device);
430 MS_ERROR_IF_NULL(hash_hit_count);
431 MS_ERROR_IF_NULL(embedding_device_cache_);
432 auto &device_hash_map = embedding_device_cache_->device_hash_map_;
433 MS_ERROR_IF_NULL(device_hash_map);
434 const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
435
436 for (size_t i = 0; i < batch_ids_len; ++i) {
437 if (batch_ids[i] < emb_table_slice_bounds_.first) {
438 hash_index[i] = batch_ids[i] - vocab_cache_size_diff_;
439 out_range[i] = true;
440 continue;
441 }
442 if (batch_ids[i] >= emb_table_slice_bounds_.second) {
443 hash_index[i] = batch_ids[i] + cache_indices_bounds_.second;
444 out_range[i] = true;
445 continue;
446 }
447 auto iter = hash_id_to_index.find(batch_ids[i]);
448 if (iter != hash_id_to_index.end()) {
449 hash_index[i] = iter->second + cache_indices_bounds_.first;
450 if (device_hash_map->hash_step(iter->second) != data_step_) {
451 ++(*hash_hit_count);
452 device_hash_map->set_hash_step(iter->second, data_step_);
453 }
454 in_device[i] = true;
455 }
456 }
457 return true;
458 }
459
CheckCacheHitOrOutRange(const int * batch_ids,const size_t batch_ids_len,int * hash_index,bool * in_device,bool * out_range)460 bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
461 bool *in_device, bool *out_range) {
462 MS_ERROR_IF_NULL(batch_ids);
463 MS_ERROR_IF_NULL(hash_index);
464 MS_ERROR_IF_NULL(in_device);
465 MS_ERROR_IF_NULL(out_range);
466
467 size_t thread_num = batch_ids_len / kMaxIdsPerThread + 1;
468 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
469 std::thread threads[kMaxThreadNum];
470 size_t hash_hit_count[kMaxThreadNum] = {0};
471 size_t i = 0;
472 size_t task_offset = 0;
473
474 for (; i < thread_num; ++i) {
475 if (task_offset >= batch_ids_len) {
476 break;
477 }
478 size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
479 threads[i] =
480 std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
481 hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
482 task_offset += task_proc_lens;
483 }
484 if (task_offset != batch_ids_len) {
485 MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
486 }
487
488 for (size_t j = 0; j < i; j++) {
489 threads[j].join();
490 }
491 for (size_t j = 0; j < i; j++) {
492 statistics_info_.hash_hit_count_ += hash_hit_count[j];
493 }
494 return true;
495 }
496
ResetEmbeddingHashMap()497 bool PsCacheManager::ResetEmbeddingHashMap() {
498 MS_ERROR_IF_NULL(embedding_device_cache_);
499 const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
500 MS_ERROR_IF_NULL(device_hash_map);
501 MS_ERROR_IF_NULL(embedding_host_cache_);
502 const auto &host_hash_map = embedding_host_cache_->host_hash_map_;
503 MS_ERROR_IF_NULL(host_hash_map);
504 device_hash_map->Reset();
505 host_hash_map->Reset();
506 device_need_wait_graph_ = false;
507 host_need_wait_graph_ = false;
508 return true;
509 }
510
ParseData(const int * batch_ids,const size_t batch_ids_len,int * hash_index)511 bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
512 MS_ERROR_IF_NULL(batch_ids);
513 MS_ERROR_IF_NULL(hash_index);
514 statistics_info_.batch_id_count_ = batch_ids_len;
515 std::unique_ptr<bool[]> in_device = std::make_unique<bool[]>(batch_ids_len);
516 std::unique_ptr<bool[]> out_range = std::make_unique<bool[]>(batch_ids_len);
517 if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
518 MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
519 }
520 if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
521 MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
522 }
523 RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
524 RETURN_IF_FALSE(ResetEmbeddingHashMap());
525 for (size_t i = 0; i < batch_ids_len; i++) {
526 if (in_device[i] || out_range[i]) {
527 continue;
528 }
529 bool need_swap_host_to_device = true;
530 bool need_swap_device_to_host = true;
531 int index = INVALID_INDEX_VALUE;
532 RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
533 hash_index[i] = index + cache_indices_bounds_.first;
534 if (need_swap_host_to_device) {
535 RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
536 }
537 if (need_swap_device_to_host) {
538 RETURN_IF_FALSE(ParseHostDataDeviceToHost());
539 }
540 }
541 return true;
542 }
543
WaitGraphRun()544 bool PsCacheManager::WaitGraphRun() {
545 MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
546 std::unique_lock<std::mutex> locker(data_mutex_);
547 const int64_t longest_time_to_wait = 120;
548 if (!data_prase_.wait_for(locker, std::chrono::seconds(longest_time_to_wait),
549 [this] { return graph_step_ > graph_running_step_; })) {
550 MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
551 << ", graph running step:" << graph_running_step_ << ").";
552 return false;
553 }
554 set_current_graph_step();
555 return true;
556 }
557
ParseDeviceData(size_t id,bool * need_swap_device_to_host,bool * need_swap_host_to_device,int * hash_index)558 bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device,
559 int *hash_index) {
560 MS_ERROR_IF_NULL(need_swap_device_to_host);
561 MS_ERROR_IF_NULL(need_swap_host_to_device);
562 MS_ERROR_IF_NULL(hash_index);
563 MS_ERROR_IF_NULL(embedding_device_cache_);
564 auto &device_hash_map = embedding_device_cache_->device_hash_map_;
565 MS_ERROR_IF_NULL(device_hash_map);
566
567 int index = INVALID_INDEX_VALUE;
568 const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
569 const auto &iter = hash_id_to_index.find(id);
570 if (iter != hash_id_to_index.end()) {
571 *need_swap_device_to_host = false;
572 *need_swap_host_to_device = false;
573 index = iter->second;
574 if (device_hash_map->hash_step(index) != data_step_) {
575 statistics_info_.hash_hit_count_++;
576 device_hash_map->set_hash_step(index, data_step_);
577 }
578 } else {
579 int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
580 int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
581 int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
582 int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
583 MS_ERROR_IF_NULL(host_to_device_index);
584 MS_ERROR_IF_NULL(host_to_device_ids);
585 auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
586 while (true) {
587 index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
588 &(statistics_info_.device_to_host_size_), &device_need_wait_graph_);
589 if (index == INVALID_INDEX_VALUE) {
590 if (!WaitGraphRun()) {
591 return false;
592 }
593 continue;
594 }
595 host_to_device_index[statistics_info_.host_to_device_size_] = index;
596 host_to_device_ids[statistics_info_.host_to_device_size_] = id;
597 statistics_info_.host_to_device_size_++;
598 *need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size;
599 break;
600 }
601 }
602 *hash_index = index;
603 return true;
604 }
605
ParseHostDataHostToDevice(size_t id)606 bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
607 MS_ERROR_IF_NULL(embedding_host_cache_);
608 int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
609 MS_ERROR_IF_NULL(host_to_device_index);
610 auto &host_hash_map = embedding_host_cache_->host_hash_map_;
611 MS_ERROR_IF_NULL(host_hash_map);
612
613 const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
614 const auto &iter = hash_id_to_index.find(id);
615 if (iter != hash_id_to_index.end()) {
616 auto index = iter->second;
617 if (host_hash_map->hash_step(index) != data_step_) {
618 host_hash_map->set_hash_step(index, data_step_);
619 }
620 host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
621 } else {
622 int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
623 int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
624 int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
625 int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
626 MS_ERROR_IF_NULL(server_to_host_index);
627 MS_ERROR_IF_NULL(server_to_host_ids);
628 while (true) {
629 auto index =
630 host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_,
631 &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
632 if (index == INVALID_INDEX_VALUE) {
633 RETURN_IF_FALSE(WaitGraphRun());
634 continue;
635 }
636 host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
637 server_to_host_index[statistics_info_.server_to_host_size_] = index;
638 server_to_host_ids[statistics_info_.server_to_host_size_++] = id;
639 break;
640 }
641 }
642 return true;
643 }
644
ParseHostDataDeviceToHost()645 bool PsCacheManager::ParseHostDataDeviceToHost() {
646 MS_ERROR_IF_NULL(embedding_device_cache_);
647 MS_ERROR_IF_NULL(embedding_host_cache_);
648 int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
649 int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
650 MS_ERROR_IF_NULL(device_to_host_ids);
651 MS_ERROR_IF_NULL(device_to_host_index);
652
653 auto &host_hash_map = embedding_host_cache_->host_hash_map_;
654 MS_ERROR_IF_NULL(host_hash_map);
655 int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
656 const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
657 const auto &iter = hash_id_to_index.find(swap_device_to_host_id);
658 if (iter != hash_id_to_index.end()) {
659 auto index = iter->second;
660 if (host_hash_map->hash_step(index) != data_step_) {
661 host_hash_map->set_hash_step(index, data_step_);
662 }
663 device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
664 } else {
665 int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
666 int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
667 while (true) {
668 auto index =
669 host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_,
670 graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
671 if (index == INVALID_INDEX_VALUE) {
672 RETURN_IF_FALSE(WaitGraphRun());
673 continue;
674 }
675 device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
676 break;
677 }
678 }
679 return true;
680 }
681
LookUpTableTask(size_t indices_lens,size_t outer_dim_size,size_t first_dim_size,const float * input_addr,const int * indices_addr,float * output_addr)682 void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
683 const float *input_addr, const int *indices_addr, float *output_addr) {
684 MS_ERROR_IF_NULL_WO_RET_VAL(input_addr);
685 MS_ERROR_IF_NULL_WO_RET_VAL(indices_addr);
686 MS_ERROR_IF_NULL_WO_RET_VAL(output_addr);
687 auto type_size = sizeof(float);
688 size_t lens = outer_dim_size * type_size;
689 for (size_t i = 0; i < indices_lens; ++i) {
690 int index = indices_addr[i];
691 if (index >= 0 && index < SizeToInt(first_dim_size)) {
692 size_t pos = index * outer_dim_size;
693 auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
694 if (ret != EOK) {
695 MS_LOG(ERROR) << "LookUpTable task memcpy failed.";
696 running_ = false;
697 return;
698 }
699 } else {
700 auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
701 if (ret != EOK) {
702 MS_LOG(ERROR) << "LookUpTable task memset failed.";
703 running_ = false;
704 return;
705 }
706 }
707 output_addr += outer_dim_size;
708 }
709 }
710
LookUpHostHashTable(size_t embedding_size,size_t indices_lens,const float * hash_table_addr,const int * indices_addr,float * output_addr)711 bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
712 const int *indices_addr, float *output_addr) {
713 MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
714 MS_ERROR_IF_NULL_W_RET_VAL(indices_addr, false);
715 MS_ERROR_IF_NULL_W_RET_VAL(output_addr, false);
716 size_t first_dim_size = host_vocab_cache_size_;
717 size_t outer_dim_size = embedding_size;
718
719 size_t thread_num = indices_lens / kMaxIdsPerThread + 1;
720 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
721 std::thread threads[kMaxThreadNum];
722 size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
723 size_t i = 0;
724 size_t task_offset = 0;
725 MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
726 for (; i < thread_num; i++) {
727 if (task_offset >= indices_lens) {
728 break;
729 }
730 MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
731 threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size,
732 hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size);
733 task_offset += task_proc_lens;
734 if (task_offset + task_proc_lens > indices_lens) {
735 task_proc_lens = indices_lens - task_offset;
736 }
737 }
738 for (size_t j = 0; j < i; j++) {
739 threads[j].join();
740 }
741 return running_;
742 }
743
InsertHostHashTable(size_t embedding_size,size_t insert_indices_size,const int * insert_indices,const float * insert_data,float * hash_table_addr)744 bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
745 const float *insert_data, float *hash_table_addr) {
746 MS_ERROR_IF_NULL_W_RET_VAL(insert_indices, false);
747 MS_ERROR_IF_NULL_W_RET_VAL(insert_data, false);
748 MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
749 size_t first_dim_size = host_vocab_cache_size_;
750 size_t thread_num = insert_indices_size / kMaxIdsPerThread + 1;
751 thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
752 std::thread threads[kMaxThreadNum];
753 size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
754 size_t i = 0;
755 size_t task_offset = 0;
756
757 auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
758 const int *insert_indices, const float *insert_data, float *hash_table_addr) {
759 auto type_size = sizeof(float);
760 size_t copy_len = outer_dim_size * type_size;
761 size_t dest_len = copy_len;
762 for (size_t i = 0; i < insert_indices_size; ++i) {
763 int index = insert_indices[i];
764 if (index >= 0 && index < SizeToInt(first_dim_size)) {
765 auto ret =
766 memcpy_s(hash_table_addr + index * outer_dim_size, dest_len, insert_data + i * outer_dim_size, copy_len);
767 if (ret != EOK) {
768 MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
769 running_ = false;
770 return;
771 }
772 }
773 }
774 };
775
776 for (; i < thread_num; i++) {
777 if (task_offset >= insert_indices_size) {
778 break;
779 }
780 MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens;
781 threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size,
782 insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr);
783 task_offset += task_proc_lens;
784 if (task_offset + task_proc_lens > insert_indices_size) {
785 task_proc_lens = insert_indices_size - task_offset;
786 }
787 }
788
789 for (size_t j = 0; j < i; j++) {
790 threads[j].join();
791 }
792 return running_;
793 }
794
HashSwapHostToDevice(const HashTableInfo & hash_info)795 bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
796 MS_ERROR_IF_NULL(embedding_device_cache_);
797 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
798 MS_ERROR_IF_NULL(embedding_host_cache_);
799 auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
800 auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get();
801 auto swap_indices_size = statistics_info_.host_to_device_size_;
802 if (swap_indices_size == 0) {
803 return true;
804 }
805 auto embedding_size = hash_info.embedding_size;
806 MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
807 auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
808 auto cache_vocab_size = hash_info.cache_vocab_size;
809 MS_ERROR_IF_NULL_W_RET_VAL(hash_info.host_address, false);
810 auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
811 auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
812 RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
813 host_cache_host_to_device_index, swap_out_data.get()));
814 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
815 embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
816 swap_indices_size * embedding_size * sizeof(float)));
817 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
818 device_cache_host_to_device_index,
819 swap_indices_size * sizeof(int)));
820 RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
821 hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
822 cache_vocab_size, embedding_size, swap_indices_size));
823 RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
824 return true;
825 }
826
HashSwapDeviceToHost(const HashTableInfo & hash_info)827 bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
828 MS_ERROR_IF_NULL(embedding_device_cache_);
829 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
830 MS_ERROR_IF_NULL(embedding_host_cache_);
831 auto swap_indices_size = statistics_info_.device_to_host_size_;
832 auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
833 auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get();
834 if (swap_indices_size == 0) {
835 return true;
836 }
837 auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
838 auto cache_vocab_size = hash_info.cache_vocab_size;
839 auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
840 auto embedding_size = hash_info.embedding_size;
841 auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
842 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
843 device_cache_device_to_host_index,
844 swap_indices_size * sizeof(int)));
845 RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
846 hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
847 cache_vocab_size, embedding_size, swap_indices_size));
848 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
849 swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
850 swap_indices_size * embedding_size * sizeof(float)));
851 RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
852 RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
853 swap_out_data.get(), host_hash_table_addr));
854 return true;
855 }
856
HashSwapHostToServer(size_t key,const HashTableInfo & hash_info)857 bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
858 MS_ERROR_IF_NULL(embedding_host_cache_);
859 auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
860 MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_ids, false);
861 auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
862 MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_index, false);
863 auto swap_indices_size = statistics_info_.host_to_server_size_;
864 if (swap_indices_size == 0) {
865 return true;
866 }
867 std::vector<int> lookup_ids(swap_indices_size, 0);
868 std::vector<float> swap_out_data;
869 auto embedding_size = hash_info.embedding_size;
870 swap_out_data.resize(swap_indices_size * embedding_size);
871 auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
872 RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
873 swap_out_data.data()));
874
875 size_t copy_len = swap_indices_size * sizeof(int);
876 size_t dest_len = copy_len;
877 auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids, copy_len);
878 if (ret != EOK) {
879 MS_LOG(ERROR) << "Lookup id memcpy failed.";
880 return false;
881 }
882 Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
883 return true;
884 }
885
HashSwapServerToHost(size_t key,const HashTableInfo & hash_info)886 bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
887 MS_ERROR_IF_NULL(embedding_host_cache_);
888 auto swap_indices_size = statistics_info_.server_to_host_size_;
889 auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
890 MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_ids, false);
891 auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
892 MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_index, false);
893 if (swap_indices_size == 0) {
894 return true;
895 }
896 auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
897 MS_ERROR_IF_NULL_W_RET_VAL(host_hash_table_addr, false);
898 auto embedding_size = hash_info.embedding_size;
899 std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
900 std::vector<int> lookup_ids(swap_indices_size, 0);
901 size_t copy_len = swap_indices_size * sizeof(int);
902 size_t dest_len = copy_len;
903 auto ret = memcpy_s(lookup_ids.data(), dest_len, server_to_host_ids, copy_len);
904 if (ret != EOK) {
905 MS_LOG(ERROR) << "Lookup id memcpy failed.";
906 return false;
907 }
908 Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
909 RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
910 lookup_result.data(), host_hash_table_addr));
911 return true;
912 }
913
HashSwapDeviceOut(int * swap_out_index,std::vector<float> * swap_out_data,const HashTableInfo & hash_info)914 bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data,
915 const HashTableInfo &hash_info) {
916 MS_ERROR_IF_NULL(swap_out_index);
917 MS_ERROR_IF_NULL(swap_out_data);
918 MS_ERROR_IF_NULL(embedding_device_cache_);
919 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
920 auto swap_out_index_size = statistics_info_.device_to_host_size_;
921 if (swap_out_index_size == 0) {
922 return true;
923 }
924
925 MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
926 auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
927 auto cache_vocab_size = hash_info.cache_vocab_size;
928 auto embedding_size = hash_info.embedding_size;
929 swap_out_data->resize(swap_out_index_size * embedding_size);
930 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
931 embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
932 RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
933 hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
934 cache_vocab_size, embedding_size, swap_out_index_size));
935 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
936 swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
937 swap_out_index_size * embedding_size * sizeof(float)));
938 RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent());
939 return true;
940 }
941
HashSwapDeviceIn(const int * swap_in_ids,const int * swap_in_index,const HashTableInfo & hash_info,size_t key)942 bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info,
943 size_t key) {
944 MS_ERROR_IF_NULL(swap_in_ids);
945 MS_ERROR_IF_NULL(swap_in_index);
946 MS_ERROR_IF_NULL(embedding_device_cache_);
947 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
948 auto swap_in_ids_size = statistics_info_.host_to_device_size_;
949 if (swap_in_ids_size == 0) {
950 return true;
951 }
952
953 MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
954 auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
955 auto cache_vocab_size = hash_info.cache_vocab_size;
956 auto embedding_size = hash_info.embedding_size;
957 // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
958 std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
959 std::vector<int> lookup_ids(swap_in_ids_size, 0);
960 size_t copy_len = swap_in_ids_size * sizeof(int);
961 size_t dest_len = copy_len;
962 auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_in_ids, copy_len);
963 if (ret != EOK) {
964 MS_LOG(ERROR) << "Lookup id memcpy failed.";
965 return false;
966 }
967 Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
968 // Hash swap-in in device.
969 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
970 embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
971 swap_in_ids_size * embedding_size * sizeof(float)));
972 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
973 swap_in_index, swap_in_ids_size * sizeof(int)));
974 RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
975 hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
976 cache_vocab_size, embedding_size, swap_in_ids_size));
977 return true;
978 }
979
UpdataEmbeddingTable(const std::vector<float> & swap_out_data,int * const swap_out_ids,size_t key)980 bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids,
981 size_t key) {
982 MS_ERROR_IF_NULL(embedding_device_cache_);
983 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
984 MS_ERROR_IF_NULL(swap_out_ids);
985 auto swap_out_ids_size = statistics_info_.device_to_host_size_;
986 if (swap_out_ids_size == 0) {
987 return true;
988 }
989 std::vector<int> lookup_ids(swap_out_ids_size, 0);
990 size_t copy_len = swap_out_ids_size * sizeof(int);
991 size_t dest_len = copy_len;
992 auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_out_ids, copy_len);
993 if (ret != EOK) {
994 MS_LOG(ERROR) << "Lookup id memcpy failed.";
995 return false;
996 }
997 // Need synchronize event to ensure that the swap-out in device is completed.
998 RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent());
999 Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1000 return true;
1001 }
1002
SyncEmbeddingTable()1003 void PsCacheManager::SyncEmbeddingTable() {
1004 if (finish_embedding_table_sync_) {
1005 return;
1006 }
1007 if (!initialized_ps_cache_) {
1008 return;
1009 }
1010 if (!SyncHostEmbeddingTable()) {
1011 MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
1012 }
1013 if (!SyncDeviceEmbeddingTable()) {
1014 MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
1015 }
1016 finish_embedding_table_sync_ = true;
1017 }
1018
SyncHostEmbeddingTable()1019 bool PsCacheManager::SyncHostEmbeddingTable() {
1020 MS_ERROR_IF_NULL(embedding_host_cache_);
1021 MS_ERROR_IF_NULL(embedding_host_cache_->host_hash_map_);
1022 const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();
1023 size_t swap_indices_lens = hash_id_to_index.size();
1024 if (swap_indices_lens == 0) {
1025 return true;
1026 }
1027 std::unique_ptr<int[]> host_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1028 MS_ERROR_IF_NULL(host_to_server_ids_ptr);
1029 std::unique_ptr<int[]> host_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1030 MS_ERROR_IF_NULL(host_to_server_indices_ptr);
1031 size_t idx = 0;
1032 for (const auto &item : hash_id_to_index) {
1033 host_to_server_ids_ptr[idx] = item.first;
1034 host_to_server_indices_ptr[idx++] = item.second;
1035 }
1036 for (const auto &item : hash_tables_) {
1037 const auto &hash_info = item.second;
1038 if (hash_info.param_init_info_.param_type_ != kWeight) {
1039 continue;
1040 }
1041 auto key = Worker::GetInstance().GetParamKey(item.first);
1042 std::vector<int> lookup_ids(swap_indices_lens, 0);
1043 std::vector<float> swap_out_data;
1044 auto embedding_size = hash_info.embedding_size;
1045 swap_out_data.resize(swap_indices_lens * embedding_size);
1046 auto host_hash_table_addr = hash_info.host_address.get();
1047 MS_ERROR_IF_NULL(host_hash_table_addr);
1048 RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr,
1049 host_to_server_indices_ptr.get(), swap_out_data.data()));
1050
1051 size_t copy_len = swap_indices_lens * sizeof(int);
1052 size_t dest_len = copy_len;
1053 auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids_ptr.get(), copy_len);
1054 if (ret != EOK) {
1055 MS_LOG(ERROR) << "Lookup id memcpy failed.";
1056 return false;
1057 }
1058 Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1059 }
1060 return true;
1061 }
1062
SyncDeviceEmbeddingTable()1063 bool PsCacheManager::SyncDeviceEmbeddingTable() {
1064 MS_ERROR_IF_NULL(embedding_device_cache_);
1065 MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
1066 const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
1067 MS_ERROR_IF_NULL(device_hash_map);
1068 const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
1069 size_t swap_indices_lens = hash_id_to_index.size();
1070 if (swap_indices_lens == 0) {
1071 return true;
1072 }
1073 std::unique_ptr<int[]> device_to_server_ids_ptr = std::make_unique<int[]>(swap_indices_lens);
1074 MS_ERROR_IF_NULL(device_to_server_ids_ptr);
1075 std::unique_ptr<int[]> device_to_server_indices_ptr = std::make_unique<int[]>(swap_indices_lens);
1076 MS_ERROR_IF_NULL(device_to_server_indices_ptr);
1077 size_t idx = 0;
1078 for (const auto &item : hash_id_to_index) {
1079 device_to_server_ids_ptr[idx] = item.first;
1080 device_to_server_indices_ptr[idx++] = item.second;
1081 }
1082 for (const auto &item : hash_tables_) {
1083 const auto &hash_info = item.second;
1084 if (hash_info.param_init_info_.param_type_ != kWeight) {
1085 continue;
1086 }
1087 auto key = Worker::GetInstance().GetParamKey(item.first);
1088 std::vector<int> lookup_ids(swap_indices_lens, 0);
1089 std::vector<float> swap_out_data;
1090 auto embedding_size = hash_info.embedding_size;
1091 swap_out_data.resize(swap_indices_lens * embedding_size);
1092 std::unique_ptr<float[]> device_hash_table_addr_tmp =
1093 std::make_unique<float[]>(device_hash_map->hash_capacity() * embedding_size);
1094 MS_ERROR_IF_NULL(device_hash_table_addr_tmp);
1095
1096 auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
1097 MS_ERROR_IF_NULL(hash_table_addr);
1098 auto hash_table_size = hash_info.device_address.size;
1099 RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(device_hash_table_addr_tmp.get(),
1100 hash_table_addr, hash_table_size));
1101 RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
1102 RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
1103 device_to_server_indices_ptr.get(), swap_out_data.data()));
1104
1105 size_t copy_len = swap_indices_lens * sizeof(int);
1106 size_t dest_len = copy_len;
1107 auto ret = memcpy_s(lookup_ids.data(), dest_len, device_to_server_ids_ptr.get(), copy_len);
1108 if (ret != EOK) {
1109 MS_LOG(ERROR) << "Lookup id memcpy failed.";
1110 return false;
1111 }
1112 Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
1113 }
1114 return true;
1115 }
1116
DumpHashTables(bool dump_device_tables) const1117 void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
1118 MS_EXCEPTION_IF_NULL(embedding_device_cache_);
1119 MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
1120 for (const auto &item : hash_tables_) {
1121 const auto ¶m_name = item.first;
1122 size_t cache_vocab_size = item.second.cache_vocab_size;
1123 size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
1124 size_t embedding_size = item.second.embedding_size;
1125 size_t vocab_size = item.second.vocab_size;
1126 MS_LOG(INFO) << "Hash table info:"
1127 << " embedding table name:" << param_name << ", vocab size:" << vocab_size
1128 << ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size
1129 << ", host cache size:" << host_cache_vocab_size
1130 << ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
1131 << ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
1132 if (dump_device_tables) {
1133 std::unique_ptr<float[]> output = std::make_unique<float[]>(item.second.device_address.size / sizeof(float));
1134 embedding_device_cache_->cache_->CopyDeviceMemToHost(output.get(), item.second.device_address.addr,
1135 item.second.device_address.size);
1136 embedding_device_cache_->cache_->SynchronizeStream();
1137 for (size_t i = 0; i < cache_vocab_size; i++) {
1138 for (size_t j = 0; j < embedding_size; j++) {
1139 std::cout << output[i * embedding_size + j] << " ";
1140 }
1141 std::cout << std::endl;
1142 }
1143 std::cout << std::endl;
1144 }
1145 }
1146 }
1147
DumpStatisticsInfo(size_t each_print_step)1148 void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) {
1149 // Default each 1000 step prints ps cache hit rate.
1150 const size_t kFloatToPercentSign = 100;
1151 if (data_step_ % each_print_step == 0) {
1152 statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
1153 auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) /
1154 statistics_info_.batch_id_count_;
1155 auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
1156 auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) /
1157 statistics_info_.batch_id_unique_count_;
1158 MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_
1159 << ", unique id num:" << statistics_info_.batch_id_unique_count_
1160 << ", host swap to device num:" << statistics_info_.host_to_device_size_
1161 << ", device swap to host num:" << statistics_info_.device_to_host_size_
1162 << ", host swap to server num:" << statistics_info_.host_to_server_size_
1163 << ", server swap to host num:" << statistics_info_.server_to_host_size_
1164 << ", data repeat rate:" << (repeat_rate * kFloatToPercentSign)
1165 << "%, device cache hit rate:" << (device_hit_rate * kFloatToPercentSign)
1166 << "%, host cache hit rate:" << (host_hit_rate * kFloatToPercentSign) << ").";
1167 }
1168 }
1169 } // namespace ps
1170 } // namespace mindspore
1171