• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include <algorithm>
#include <exception>
#include <mutex>
#include <numeric>
#include <set>
#include <thread>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "include/backend/distributed/cluster/cluster_context.h"
#endif
#include "include/backend/distributed/ps/ps_context.h"
#include "distributed/embedding_cache/embedding_storage/dense_embedding_storage.h"
#include "distributed/embedding_cache/embedding_storage/sparse_embedding_storage.h"
#include "include/backend/distributed/embedding_cache/embedding_storage/abstract_embedding_storage.h"
29 
30 namespace mindspore {
31 namespace distributed {
GetInstance()32 EmbeddingCacheTableManager &EmbeddingCacheTableManager::GetInstance() {
33   static EmbeddingCacheTableManager instance{};
34   return instance;
35 }
36 
Initialize()37 void EmbeddingCacheTableManager::Initialize() {
38   auto worker_num = ps::PSContext::instance()->worker_num();
39   multi_batch_threshold_ = worker_num > 1 ? 1 : kMultiBatchThreshold;
40   GetEmbeddingTableSliceBound();
41 
42   device::DeviceContextKey host_key = {"CPU", 0};
43   cpu_device_context_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
44   MS_EXCEPTION_IF_NULL(cpu_device_context_);
45   cpu_device_context_->Initialize();
46 }
47 
Finalize(const device::DeviceContext * device_context)48 void EmbeddingCacheTableManager::Finalize(const device::DeviceContext *device_context) {
49   hash_tables_.clear();
50 
51   device_hash_map_ = nullptr;
52   host_hash_map_ = nullptr;
53 
54   MS_EXCEPTION_IF_NULL(device_context);
55   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
56   if (hash_swap_index_addr_) {
57     device_context->device_res_manager_->FreeMemory(hash_swap_index_addr_);
58   }
59   if (hash_swap_value_addr_) {
60     device_context->device_res_manager_->FreeMemory(hash_swap_value_addr_);
61   }
62   for (auto &item : hash_tables_) {
63     if (item.second.host_address) {
64       MS_EXCEPTION_IF_NULL(cpu_device_context_);
65       MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
66       cpu_device_context_->device_res_manager_->FreeMemory(item.second.host_address);
67     }
68   }
69 }
70 
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size,int32_t param_key)71 void EmbeddingCacheTableManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size,
72                                                      size_t embedding_size, size_t vocab_size, int32_t param_key) {
73   if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
74     MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
75   }
76   hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
77   hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
78   hash_tables_[param_name].embedding_size = embedding_size;
79   hash_tables_[param_name].vocab_size = vocab_size;
80   hash_tables_[param_name].param_key_ = param_key;
81 
82   if (vocab_size_ == 0) {
83     vocab_size_ = vocab_size;
84   }
85   if (device_cache_size_ == 0) {
86     device_cache_size_ = cache_vocab_size;
87   }
88   if (host_cache_size_ == 0) {
89     host_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
90   }
91 }
92 
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name)93 void EmbeddingCacheTableManager::ReInsertHashTableSize(const std::string &new_param_name,
94                                                        const std::string &cur_param_name) {
95   if (new_param_name.empty() || cur_param_name.empty()) {
96     MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
97   }
98   if (new_param_name == cur_param_name) {
99     return;
100   }
101   auto iter = hash_tables_.find(cur_param_name);
102   if (iter == hash_tables_.end()) {
103     MS_LOG(EXCEPTION) << "Can not find parameter[" << cur_param_name << "] in hash table.";
104   }
105   (void)hash_tables_.emplace(new_param_name, iter->second);
106   (void)hash_tables_.erase(iter);
107 }
108 
InsertAccumuInitInfo(const std::string & param_name,float init_val)109 void EmbeddingCacheTableManager::InsertAccumuInitInfo(const std::string &param_name, float init_val) {
110   auto iter = hash_tables_.find(param_name);
111   if (iter == hash_tables_.end()) {
112     MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
113   }
114   auto &hash_table_info = iter->second;
115   if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
116     return;
117   }
118   MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
119   hash_table_info.param_init_info_.param_name_ = param_name;
120   hash_table_info.param_init_info_.param_type_ = kAccumulation;
121   hash_table_info.param_init_info_.init_val_ = init_val;
122 }
123 
CloneHashTable(const std::string & dest_param_name,int32_t dest_param_key,const std::string & src_param_name,int32_t src_param_key)124 void EmbeddingCacheTableManager::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
125                                                 const std::string &src_param_name, int32_t src_param_key) {
126   if (dest_param_name == src_param_name) {
127     MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
128     return;
129   }
130   auto iter = hash_tables_.find(src_param_name);
131   if (iter == hash_tables_.end()) {
132     MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
133   }
134   (void)hash_tables_.emplace(dest_param_name, iter->second);
135   hash_tables_[src_param_name].param_key_ = src_param_key;
136   hash_tables_[dest_param_name].param_key_ = dest_param_key;
137 }
138 
QueryEmbeddingDeviceAddress(const std::string & param_name) const139 const DeviceAddress *EmbeddingCacheTableManager::QueryEmbeddingDeviceAddress(const std::string &param_name) const {
140   auto iter = hash_tables_.find(param_name);
141   if (iter == hash_tables_.end()) {
142     MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
143   }
144   return iter->second.device_address;
145 }
146 
QueryHashTableSize(const std::string & param_name) const147 size_t EmbeddingCacheTableManager::QueryHashTableSize(const std::string &param_name) const {
148   auto iter = hash_tables_.find(param_name);
149   if (iter == hash_tables_.end()) {
150     MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
151   }
152   return iter->second.cache_vocab_size;
153 }
154 
SetEmbeddingDeviceAddress(const std::string & param_name,DeviceAddress * device_address)155 void EmbeddingCacheTableManager::SetEmbeddingDeviceAddress(const std::string &param_name,
156                                                            DeviceAddress *device_address) {
157   MS_EXCEPTION_IF_NULL(device_address);
158   auto iter = hash_tables_.find(param_name);
159   if (iter == hash_tables_.end()) {
160     MS_LOG(EXCEPTION) << "Can not find hash table info for " << param_name;
161   }
162   iter->second.device_address = device_address;
163 }
164 
AllocMemForEmbedding(const device::DeviceContext * device_context)165 void EmbeddingCacheTableManager::AllocMemForEmbedding(const device::DeviceContext *device_context) {
166   MS_EXCEPTION_IF_NULL(device_context);
167   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
168 
169   size_t max_embedding_size = 0;
170   for (auto &item : hash_tables_) {
171     auto *device_address = item.second.device_address;
172     MS_EXCEPTION_IF_NULL(device_address);
173     if (device_address->GetPtr() == nullptr) {
174       MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->AllocateMemory(device_address),
175                                  "Allocate device memory for embedding table failed.");
176     }
177     item.second.address = Address(device_address->GetMutablePtr(), device_address->GetSize());
178 
179     size_t embedding_size = item.second.embedding_size;
180     auto &host_address = item.second.host_address;
181     host_address = reinterpret_cast<float *>(
182       cpu_device_context_->device_res_manager_->AllocateMemory(host_cache_size_ * embedding_size * sizeof(float)));
183     MS_EXCEPTION_IF_NULL(host_address);
184 
185     max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
186   }
187 
188   device_hash_map_ = std::make_shared<EmbeddingHashMap>(device_cache_size_);
189   MS_EXCEPTION_IF_NULL(device_hash_map_);
190   host_hash_map_ = std::make_shared<EmbeddingHashMap>(host_cache_size_);
191   MS_EXCEPTION_IF_NULL(host_hash_map_);
192 
193   hash_swap_index_addr_ = reinterpret_cast<int *>(
194     device_context->device_res_manager_->AllocateMemory(batch_ids_num_ * sizeof(int) * multi_batch_threshold_));
195   MS_EXCEPTION_IF_NULL(hash_swap_index_addr_);
196   hash_swap_value_addr_ = reinterpret_cast<float *>(device_context->device_res_manager_->AllocateMemory(
197     max_embedding_size * batch_ids_num_ * sizeof(float) * multi_batch_threshold_));
198   MS_EXCEPTION_IF_NULL(hash_swap_value_addr_);
199 }
200 
GetEmbeddingTableSliceBound()201 void EmbeddingCacheTableManager::GetEmbeddingTableSliceBound() {
202   auto worker_num = ps::PSContext::instance()->worker_num();
203   auto server_num = ps::PSContext::instance()->server_num();
204   if (worker_num == 0) {
205     return;
206   }
207   if (is_sparse_format() && (worker_num > 1 || server_num > 1)) {
208     MS_LOG(EXCEPTION) << "The sparse format can not support multi worker or multi server currently.";
209   }
210 
211   uint32_t rank_id = 0;
212 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
213   auto node = distributed::cluster::ClusterContext::instance()->node();
214   MS_EXCEPTION_IF_NULL(node);
215   rank_id = node->rank_id();
216 #endif
217 
218   if (!is_sparse_format()) {
219     int local_shard_size = UintToInt(SizeToUint(vocab_size_) / worker_num);
220     if (vocab_size_ % worker_num != 0) {
221       local_shard_size += 1;
222     }
223     local_embedding_slice_bounds_.first = local_shard_size * UintToInt(rank_id);
224     local_embedding_slice_bounds_.second =
225       std::min(local_embedding_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
226   } else {
227     local_embedding_slice_bounds_.first = 0;
228     local_embedding_slice_bounds_.second = INT_MAX;
229   }
230   local_device_cache_bounds_.first = SizeToInt(device_cache_size_) * UintToInt(rank_id);
231   local_device_cache_bounds_.second = local_device_cache_bounds_.first + SizeToInt(device_cache_size_);
232   MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id
233                << ", id begin:" << local_embedding_slice_bounds_.first
234                << ", id end:" << local_embedding_slice_bounds_.second
235                << ", cache indices begin: " << local_device_cache_bounds_.first
236                << ", cache indices end: " << local_device_cache_bounds_.second << ", vocab_size: " << vocab_size_
237                << ", device cache size: " << device_cache_size_;
238 }
239 
cache_indices_lower_bound() const240 int EmbeddingCacheTableManager::cache_indices_lower_bound() const { return local_device_cache_bounds_.first; }
241 
DumpHashTables() const242 void EmbeddingCacheTableManager::DumpHashTables() const {
243   for (const auto &item : hash_tables_) {
244     const auto &param_name = item.first;
245     size_t cache_vocab_size = item.second.cache_vocab_size;
246     size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
247     size_t embedding_size = item.second.embedding_size;
248     size_t vocab_size = item.second.vocab_size;
249     int32_t param_key = item.second.param_key_;
250     MS_LOG(INFO) << "Hash table info:"
251                  << " param_key:" << param_key << ", embedding table name:" << param_name
252                  << ", vocab size:" << vocab_size << ", embedding size:" << embedding_size
253                  << ", device cache size:" << cache_vocab_size << ", host cache size:" << host_cache_vocab_size
254                  << ", device cache address:" << item.second.address.addr
255                  << ", host cache address:" << item.second.host_address;
256   }
257 }
258 
enable_pipeline() const259 bool EmbeddingCacheTableManager::enable_pipeline() const {
260   return ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable();
261 }
262 
StoreWarmUpPtr(const int32_t param_key,const tensor::TensorPtr & tensor_ptr)263 int32_t EmbeddingCacheTableManager::StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &tensor_ptr) {
264   return StoreWarmUpPtr(param_key, nullptr, tensor_ptr, nullptr);
265 }
266 
StoreWarmUpPtr(const int32_t param_key,const tensor::TensorPtr & key_ptr,const tensor::TensorPtr & value_ptr,const tensor::TensorPtr & status_ptr)267 int32_t EmbeddingCacheTableManager::StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &key_ptr,
268                                                    const tensor::TensorPtr &value_ptr,
269                                                    const tensor::TensorPtr &status_ptr) {
270   MS_LOG(INFO) << "Enter store warm up ptr, param_key : " << param_key << ".";
271   MS_EXCEPTION_IF_NULL(value_ptr);
272   std::unique_lock<std::mutex> lock(host_cache_mutex_);
273   auto ret = host_cache_ptrs_.find(param_key);
274   if (ret != host_cache_ptrs_.end()) {
275     MS_LOG(WARNING) << "Store warm up ptr duplicate, id : " << param_key << ".";
276   }
277   (void)host_cache_ptrs_.try_emplace(param_key, key_ptr, value_ptr, status_ptr);
278   MS_LOG(INFO) << "Exit store warm up ptr, host cache ptrs size : " << host_cache_ptrs_.size()
279                << ", hash tables size : " << hash_tables_.size() << ".";
280   return 0;
281 }
282 
FindHashTablesByParamKey(const int param_key)283 const HashTableInfo *EmbeddingCacheTableManager::FindHashTablesByParamKey(const int param_key) {
284   const auto &iter = std::find_if(hash_tables_.begin(), hash_tables_.end(),
285                                   [this, param_key](const auto &data) { return data.second.param_key_ == param_key; });
286   return iter != hash_tables_.end() ? &(iter->second) : nullptr;
287 }
288 
WarmUpHostCacheSync(const int32_t batch_count)289 void EmbeddingCacheTableManager::WarmUpHostCacheSync(const int32_t batch_count) {
290   MS_LOG(INFO) << "Enter warm up host cache sync, batch_count : " << batch_count << ".";
291   auto cache_ptr_size = host_cache_ptrs_.size();
292   auto hash_table_size = hash_tables_.size();
293   if (cache_ptr_size != hash_table_size) {
294     MS_LOG(WARNING) << "Host cache ptrs size : " << cache_ptr_size
295                     << " is not equal to hash table size : " << hash_table_size
296                     << ", will skip warm up host cache sync.";
297     std::unique_lock<std::mutex> lock(host_cache_mutex_);
298     host_cache_promise_->set_value(false);
299     lock.unlock();
300     host_cache_ptrs_.clear();
301     return;
302   }
303 
304   for (auto &item : host_cache_ptrs_) {
305     WarmUpHostCacheItemBatch(batch_count, item);
306   }
307   std::unique_lock<std::mutex> lock(host_cache_mutex_);
308   host_cache_promise_->set_value(true);
309   lock.unlock();
310   host_cache_ptrs_.clear();
311   MS_LOG(INFO) << "Exit warm up host cache sync.";
312 }
313 
WarmUpHostCacheAsync(const int32_t batch_count)314 void EmbeddingCacheTableManager::WarmUpHostCacheAsync(const int32_t batch_count) {
315   MS_LOG(DEBUG) << "Enter warm up host cache async, batch_count : " << batch_count << ".";
316   std::unique_lock<std::mutex> lock(host_cache_mutex_);
317   if (host_cache_promise_ != nullptr) {
318     lock.unlock();
319     MS_LOG(WARNING) << "Host cache promise is not null, cache sync has already done.";
320     return;
321   }
322   host_cache_promise_ = std::make_shared<std::promise<bool>>();
323   lock.unlock();
324   std::thread([this, batch_count]() { WarmUpHostCacheSync(batch_count); }).detach();
325   MS_LOG(DEBUG) << "Exit warm up host cache async.";
326 }
327 
GetWarmUpHostCacheAsyncStatus()328 std::pair<std::shared_ptr<std::future<bool>>, bool> EmbeddingCacheTableManager::GetWarmUpHostCacheAsyncStatus() {
329   MS_LOG(DEBUG) << "Enter get warm up host cache async status.";
330   std::unique_lock<std::mutex> lock(host_cache_mutex_);
331   if (host_cache_promise_ == nullptr) {
332     return std::make_pair(nullptr, false);
333   }
334   return std::make_pair(std::make_shared<std::future<bool>>(host_cache_promise_->get_future()), true);
335 }
336 
WaitForWarmUpHostCacheComplete()337 bool EmbeddingCacheTableManager::WaitForWarmUpHostCacheComplete() {
338   MS_LOG(DEBUG) << "Enter wait for warm up host cache complete.";
339   const int32_t batch_count = 4;
340   WarmUpHostCacheAsync(batch_count);
341   const auto &[complete_future, status] = GetWarmUpHostCacheAsyncStatus();
342   return status ? complete_future->get() : status;
343 }
344 
/**
 * @brief Build an int32 key tensor [0, 1, ..., rows - 1], where rows is the first dimension of `tensor_ptr`.
 * @param[in] `tensor_ptr`: Non-null tensor with at least one dimension and a non-negative first dimension.
 * @return The generated key tensor. Raises an exception for a null, 0-rank or negative-leading-dimension tensor.
 */
tensor::TensorPtr generate_key_tensor_ptr(const tensor::TensorPtr &tensor_ptr) {
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  const auto &shape = tensor_ptr->shape();
  // shape[0] on a 0-rank tensor is out of bounds, and a negative (dynamic) dimension would become a huge size.
  if (shape.empty() || shape[0] < 0) {
    MS_LOG(EXCEPTION) << "Can not generate key tensor for a tensor without a valid first dimension.";
  }
  std::vector<int32_t> key_vec(static_cast<size_t>(shape[0]));
  std::iota(key_vec.begin(), key_vec.end(), 0);
  return std::make_shared<tensor::Tensor>(key_vec);
}
354 
WarmUpHostCacheItemBatch(const int32_t batch_count,const WarmUpCacheMapEntry & entry)355 void EmbeddingCacheTableManager::WarmUpHostCacheItemBatch(const int32_t batch_count, const WarmUpCacheMapEntry &entry) {
356   MS_LOG(DEBUG) << "Enter warm up host cache item batch.";
357   auto key_ptr = std::get<0>(entry.second);
358   auto value_ptr = std::get<1>(entry.second);
359   MS_EXCEPTION_IF_NULL(value_ptr);
360   // Key tensor may be nullptr since we stored single value tensor.
361   if (key_ptr == nullptr) {
362     MS_LOG(INFO) << "key_ptr is nullptr, generate key tensor.";
363     key_ptr = generate_key_tensor_ptr(value_ptr);
364   }
365   auto &vec = key_ptr->shape();
366   auto l_len = static_cast<int>(vec[0]);
367   const int32_t default_batch_count = 1;
368   const int validate_batch_count = batch_count < default_batch_count ? default_batch_count : batch_count;
369   int batch_size = l_len / validate_batch_count;
370   if (l_len % validate_batch_count != 0) {
371     batch_size++;
372   }
373 
374   if (host_hash_map_ == nullptr) {
375     MS_LOG(WARNING) << "Embedding hash map of embedding host cache is nullptr, will skip warm up.";
376     return;
377   }
378 
379   auto hash_table_info_ptr = FindHashTablesByParamKey(entry.first);
380   if (hash_table_info_ptr == nullptr) {
381     MS_LOG(WARNING) << "Hash table info is nullptr, will skip warm up.";
382     return;
383   }
384 
385   size_t host_length = (hash_table_info_ptr->host_cache_vocab_size * hash_table_info_ptr->embedding_size) << 2;
386   auto &value_shape = value_ptr->shape();
387   size_t value_len = 0;
388   (void)std::for_each(value_shape.begin() + 1, value_shape.end(), [&](int n) { value_len += n; });
389   MS_EXCEPTION_IF_NULL(value_ptr->data_ptr());
390   value_len *= static_cast<size_t>(value_ptr->data_ptr()->itemsize());
391   size_t value_expected_len = value_len * (value_shape[0] + 1);
392   MS_EXCEPTION_IF_CHECK_FAIL(value_expected_len <= host_length, "Size of value tensor is overflow.");
393 
394   for (int i = 0; i < l_len; i += batch_size) {
395     WarmUpHostCacheItem(host_hash_map_, hash_table_info_ptr, entry, i, std::min(i + batch_size, l_len), value_len);
396   }
397   MS_LOG(DEBUG) << "Exit warm up host cache item batch.";
398 }
399 
WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> & embedding_hash_map,const HashTableInfo * hash_table_info_ptr,const WarmUpCacheMapEntry & entry,const int start,const int end,const size_t value_len)400 void EmbeddingCacheTableManager::WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> &embedding_hash_map,
401                                                      const HashTableInfo *hash_table_info_ptr,
402                                                      const WarmUpCacheMapEntry &entry, const int start, const int end,
403                                                      const size_t value_len) {
404   // Value type is float, bit num is 2
405   const int shift_bit_num = 2;
406   MS_EXCEPTION_IF_NULL(hash_table_info_ptr);
407   if (hash_table_info_ptr->embedding_size != (value_len >> shift_bit_num)) {
408     MS_LOG(WARNING) << "Hash table info embedding_size : " << hash_table_info_ptr->embedding_size
409                     << " is not equal to value_len : " << value_len << ".";
410     return;
411   }
412 
413   auto key_ptr = std::get<0>(entry.second);
414   MS_EXCEPTION_IF_NULL(key_ptr);
415   auto key_data_ptr = key_ptr->data_ptr();
416   for (ssize_t i = start; i != end; i++) {
417     auto key_data_type = key_ptr->data_type();
418     int64_t key = 0;
419     switch (key_data_type) {
420       case TypeId::kNumberTypeInt32:
421       case TypeId::kNumberTypeUInt32: {
422         auto int_ptr = static_cast<int *>(key_ptr->data_c());
423         key = *(int_ptr + i);
424       } break;
425       case TypeId::kNumberTypeInt64:
426       case TypeId::kNumberTypeUInt64: {
427         auto int64_ptr = static_cast<int64_t *>(key_ptr->data_c());
428         key = *(int64_ptr + i);
429       } break;
430       default:
431         MS_LOG(WARNING) << "Invalid key_data_type : " << key_data_type << ".";
432         return;
433     }
434 
435     int id = embedding_hash_map->GetOrInsertDataUnsafe(static_cast<int>(key));
436     if (id == kInvalidIndexValue) {
437       MS_LOG(WARNING) << "Embedding hash map is full, exit warm up process.";
438       break;
439     }
440 
441     size_t offset = static_cast<size_t>(id) * value_len;
442     auto host_address = hash_table_info_ptr->host_address;
443     auto des_ptr = AddressOffset(host_address, 0);
444     auto value_data_ptr = std::get<1>(entry.second)->data_c();
445     auto src_ptr = AddressOffset(value_data_ptr, 0);
446     auto ret_code = memcpy_s(des_ptr + offset, value_len, src_ptr + i * value_len, value_len);
447     if (ret_code != EOK) {
448       MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code;
449     }
450   }
451 }
452 
GetInstance()453 EmbeddingStorageManager &EmbeddingStorageManager::GetInstance() {
454   static EmbeddingStorageManager instance{};
455   return instance;
456 }
457 
458 namespace {
459 /**
460  * @brief Create a new embedding storage instance for specific key and value type, and add the instance to
461  * EmbeddingStorageManager, this function is the implementation of function `CreateEmbeddingStorage`.
462  * @param[in] `embedding_key`: The unique parameter key for embedding table.
463  * @param[in] `embedding_dim`: The length of each embedding vector.
464  * @param[in] `capacity`: The capacity for new embedding storage.
465  */
466 template <typename KeyType, typename ValueType>
CreateEmbeddingStorageFunc(int32_t embedding_key,size_t embedding_dim,size_t capacity)467 void CreateEmbeddingStorageFunc(int32_t embedding_key, size_t embedding_dim, size_t capacity) {
468   std::shared_ptr<storage::AbstractEmbeddingStorage> embedding_storage = nullptr;
469   if (!EmbeddingCacheTableManager::GetInstance().is_sparse_format()) {
470     embedding_storage =
471       std::make_shared<storage::DenseEmbeddingStorage<KeyType, ValueType>>(embedding_key, embedding_dim, capacity);
472   } else {
473     embedding_storage =
474       std::make_shared<storage::SparseEmbeddingStorage<KeyType, ValueType>>(embedding_key, embedding_dim, capacity);
475   }
476   MS_EXCEPTION_IF_NULL(embedding_storage);
477   EmbeddingStorageManager::GetInstance().Add(embedding_key, embedding_storage);
478 }
479 
480 // Key-Value type pair -> CreateEmbeddingStorageFunc map.
481 const std::map<std::pair<TypeId, TypeId>, std::function<void(int32_t, size_t, size_t)>> kCreateEmbeddingStorageFuncs = {
482   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeBool), CreateEmbeddingStorageFunc<int32_t, bool>},
483   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt8), CreateEmbeddingStorageFunc<int32_t, int8_t>},
484   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16), CreateEmbeddingStorageFunc<int32_t, int16_t>},
485   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32), CreateEmbeddingStorageFunc<int32_t, int32_t>},
486   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt64), CreateEmbeddingStorageFunc<int32_t, int64_t>},
487   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt8), CreateEmbeddingStorageFunc<int32_t, uint8_t>},
488   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt16), CreateEmbeddingStorageFunc<int32_t, uint16_t>},
489   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt32), CreateEmbeddingStorageFunc<int32_t, uint32_t>},
490   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt64), CreateEmbeddingStorageFunc<int32_t, uint64_t>},
491   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat16), CreateEmbeddingStorageFunc<int32_t, float16>},
492   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32), CreateEmbeddingStorageFunc<int32_t, float>},
493   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat64), CreateEmbeddingStorageFunc<int32_t, double>},
494   {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeBFloat16),
495    CreateEmbeddingStorageFunc<int32_t, bfloat16>},
496 
497   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeBool), CreateEmbeddingStorageFunc<int64_t, bool>},
498   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt8), CreateEmbeddingStorageFunc<int64_t, int8_t>},
499   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt16), CreateEmbeddingStorageFunc<int64_t, int16_t>},
500   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt32), CreateEmbeddingStorageFunc<int64_t, int32_t>},
501   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64), CreateEmbeddingStorageFunc<int64_t, int64_t>},
502   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt8), CreateEmbeddingStorageFunc<int64_t, uint8_t>},
503   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt16), CreateEmbeddingStorageFunc<int64_t, uint16_t>},
504   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt32), CreateEmbeddingStorageFunc<int64_t, uint32_t>},
505   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt64), CreateEmbeddingStorageFunc<int64_t, uint64_t>},
506   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat16), CreateEmbeddingStorageFunc<int64_t, float16>},
507   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32), CreateEmbeddingStorageFunc<int64_t, float>},
508   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat64), CreateEmbeddingStorageFunc<int64_t, double>},
509   {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeBFloat16),
510    CreateEmbeddingStorageFunc<int64_t, bfloat16>}};
511 }  // namespace
512 
CreateEmbeddingStorage(std::pair<TypeId,TypeId> key_value_types,int32_t embedding_key,size_t embedding_dim,size_t capacity)513 void CreateEmbeddingStorage(std::pair<TypeId, TypeId> key_value_types, int32_t embedding_key, size_t embedding_dim,
514                             size_t capacity) {
515   const auto &iter = kCreateEmbeddingStorageFuncs.find(key_value_types);
516   if (iter == kCreateEmbeddingStorageFuncs.end()) {
517     MS_LOG(EXCEPTION) << "Can not find function to create embedding storage for key type:"
518                       << TypeIdToString(key_value_types.first)
519                       << ", value type:" << TypeIdToString(key_value_types.second);
520   }
521   iter->second(embedding_key, embedding_dim, capacity);
522 }
523 
EmbeddingDeviceCache(size_t batch_ids_num)524 EmbeddingDeviceCache::EmbeddingDeviceCache(size_t batch_ids_num) {
525   device_to_host_index = std::make_unique<int[]>(batch_ids_num);
526   device_to_host_ids = std::make_unique<int[]>(batch_ids_num);
527   host_to_device_index = std::make_unique<int[]>(batch_ids_num);
528   host_to_device_ids = std::make_unique<int[]>(batch_ids_num);
529 }
530 
EmbeddingHostCache(size_t batch_ids_num)531 EmbeddingHostCache::EmbeddingHostCache(size_t batch_ids_num) {
532   host_to_server_index = std::make_unique<int[]>(batch_ids_num);
533   host_to_server_ids = std::make_unique<int[]>(batch_ids_num);
534   server_to_host_index = std::make_unique<int[]>(batch_ids_num);
535   server_to_host_ids = std::make_unique<int[]>(batch_ids_num);
536   new_id_index = std::make_unique<int[]>(batch_ids_num);
537   host_to_device_index = std::make_unique<int[]>(batch_ids_num);
538   device_to_host_index = std::make_unique<int[]>(batch_ids_num);
539 }
540 
Add(int32_t param_key,const std::shared_ptr<storage::AbstractEmbeddingStorage> & embed_storage)541 void EmbeddingStorageManager::Add(int32_t param_key,
542                                   const std::shared_ptr<storage::AbstractEmbeddingStorage> &embed_storage) {
543   MS_EXCEPTION_IF_NULL(embed_storage);
544   embedding_storages_[param_key] = embed_storage;
545 }
546 
Get(int32_t param_key)547 std::shared_ptr<storage::AbstractEmbeddingStorage> EmbeddingStorageManager::Get(int32_t param_key) {
548   const auto &iter = embedding_storages_.find(param_key);
549   if (iter != embedding_storages_.end()) {
550     return iter->second;
551   }
552   MS_LOG(EXCEPTION) << "Can not find embedding storage for parameter key[" << param_key << "].";
553 }
554 
Clear()555 void EmbeddingStorageManager::Clear() {
556   for (const auto &item : embedding_storages_) {
557     const auto &embedding_storage = item.second;
558     MS_EXCEPTION_IF_NULL(embedding_storage);
559     embedding_storage->Finalize();
560   }
561 
562   embedding_storages_.clear();
563 }
564 }  // namespace distributed
565 }  // namespace mindspore
566