1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include <algorithm>
#include <exception>
#include <thread>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "include/backend/distributed/cluster/cluster_context.h"
#endif
#include "include/backend/distributed/ps/ps_context.h"
#include "distributed/embedding_cache/embedding_storage/dense_embedding_storage.h"
#include "distributed/embedding_cache/embedding_storage/sparse_embedding_storage.h"
#include "include/backend/distributed/embedding_cache/embedding_storage/abstract_embedding_storage.h"
29
30 namespace mindspore {
31 namespace distributed {
GetInstance()32 EmbeddingCacheTableManager &EmbeddingCacheTableManager::GetInstance() {
33 static EmbeddingCacheTableManager instance{};
34 return instance;
35 }
36
Initialize()37 void EmbeddingCacheTableManager::Initialize() {
38 auto worker_num = ps::PSContext::instance()->worker_num();
39 multi_batch_threshold_ = worker_num > 1 ? 1 : kMultiBatchThreshold;
40 GetEmbeddingTableSliceBound();
41
42 device::DeviceContextKey host_key = {"CPU", 0};
43 cpu_device_context_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(host_key);
44 MS_EXCEPTION_IF_NULL(cpu_device_context_);
45 cpu_device_context_->Initialize();
46 }
47
Finalize(const device::DeviceContext * device_context)48 void EmbeddingCacheTableManager::Finalize(const device::DeviceContext *device_context) {
49 hash_tables_.clear();
50
51 device_hash_map_ = nullptr;
52 host_hash_map_ = nullptr;
53
54 MS_EXCEPTION_IF_NULL(device_context);
55 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
56 if (hash_swap_index_addr_) {
57 device_context->device_res_manager_->FreeMemory(hash_swap_index_addr_);
58 }
59 if (hash_swap_value_addr_) {
60 device_context->device_res_manager_->FreeMemory(hash_swap_value_addr_);
61 }
62 for (auto &item : hash_tables_) {
63 if (item.second.host_address) {
64 MS_EXCEPTION_IF_NULL(cpu_device_context_);
65 MS_EXCEPTION_IF_NULL(cpu_device_context_->device_res_manager_);
66 cpu_device_context_->device_res_manager_->FreeMemory(item.second.host_address);
67 }
68 }
69 }
70
InsertHashTableSize(const std::string & param_name,size_t cache_vocab_size,size_t embedding_size,size_t vocab_size,int32_t param_key)71 void EmbeddingCacheTableManager::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size,
72 size_t embedding_size, size_t vocab_size, int32_t param_key) {
73 if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
74 MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
75 }
76 hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
77 hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
78 hash_tables_[param_name].embedding_size = embedding_size;
79 hash_tables_[param_name].vocab_size = vocab_size;
80 hash_tables_[param_name].param_key_ = param_key;
81
82 if (vocab_size_ == 0) {
83 vocab_size_ = vocab_size;
84 }
85 if (device_cache_size_ == 0) {
86 device_cache_size_ = cache_vocab_size;
87 }
88 if (host_cache_size_ == 0) {
89 host_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
90 }
91 }
92
ReInsertHashTableSize(const std::string & new_param_name,const std::string & cur_param_name)93 void EmbeddingCacheTableManager::ReInsertHashTableSize(const std::string &new_param_name,
94 const std::string &cur_param_name) {
95 if (new_param_name.empty() || cur_param_name.empty()) {
96 MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
97 }
98 if (new_param_name == cur_param_name) {
99 return;
100 }
101 auto iter = hash_tables_.find(cur_param_name);
102 if (iter == hash_tables_.end()) {
103 MS_LOG(EXCEPTION) << "Can not find parameter[" << cur_param_name << "] in hash table.";
104 }
105 (void)hash_tables_.emplace(new_param_name, iter->second);
106 (void)hash_tables_.erase(iter);
107 }
108
InsertAccumuInitInfo(const std::string & param_name,float init_val)109 void EmbeddingCacheTableManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) {
110 auto iter = hash_tables_.find(param_name);
111 if (iter == hash_tables_.end()) {
112 MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
113 }
114 auto &hash_table_info = iter->second;
115 if (hash_table_info.param_init_info_.param_type_ != kUnKnown) {
116 return;
117 }
118 MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val;
119 hash_table_info.param_init_info_.param_name_ = param_name;
120 hash_table_info.param_init_info_.param_type_ = kAccumulation;
121 hash_table_info.param_init_info_.init_val_ = init_val;
122 }
123
CloneHashTable(const std::string & dest_param_name,int32_t dest_param_key,const std::string & src_param_name,int32_t src_param_key)124 void EmbeddingCacheTableManager::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
125 const std::string &src_param_name, int32_t src_param_key) {
126 if (dest_param_name == src_param_name) {
127 MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
128 return;
129 }
130 auto iter = hash_tables_.find(src_param_name);
131 if (iter == hash_tables_.end()) {
132 MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
133 }
134 (void)hash_tables_.emplace(dest_param_name, iter->second);
135 hash_tables_[src_param_name].param_key_ = src_param_key;
136 hash_tables_[dest_param_name].param_key_ = dest_param_key;
137 }
138
QueryEmbeddingDeviceAddress(const std::string & param_name) const139 const DeviceAddress *EmbeddingCacheTableManager::QueryEmbeddingDeviceAddress(const std::string ¶m_name) const {
140 auto iter = hash_tables_.find(param_name);
141 if (iter == hash_tables_.end()) {
142 MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
143 }
144 return iter->second.device_address;
145 }
146
QueryHashTableSize(const std::string & param_name) const147 size_t EmbeddingCacheTableManager::QueryHashTableSize(const std::string ¶m_name) const {
148 auto iter = hash_tables_.find(param_name);
149 if (iter == hash_tables_.end()) {
150 MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
151 }
152 return iter->second.cache_vocab_size;
153 }
154
SetEmbeddingDeviceAddress(const std::string & param_name,DeviceAddress * device_address)155 void EmbeddingCacheTableManager::SetEmbeddingDeviceAddress(const std::string ¶m_name,
156 DeviceAddress *device_address) {
157 MS_EXCEPTION_IF_NULL(device_address);
158 auto iter = hash_tables_.find(param_name);
159 if (iter == hash_tables_.end()) {
160 MS_LOG(EXCEPTION) << "Can not find hash table info for " << param_name;
161 }
162 iter->second.device_address = device_address;
163 }
164
AllocMemForEmbedding(const device::DeviceContext * device_context)165 void EmbeddingCacheTableManager::AllocMemForEmbedding(const device::DeviceContext *device_context) {
166 MS_EXCEPTION_IF_NULL(device_context);
167 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
168
169 size_t max_embedding_size = 0;
170 for (auto &item : hash_tables_) {
171 auto *device_address = item.second.device_address;
172 MS_EXCEPTION_IF_NULL(device_address);
173 if (device_address->GetPtr() == nullptr) {
174 MS_EXCEPTION_IF_CHECK_FAIL(device_context->device_res_manager_->AllocateMemory(device_address),
175 "Allocate device memory for embedding table failed.");
176 }
177 item.second.address = Address(device_address->GetMutablePtr(), device_address->GetSize());
178
179 size_t embedding_size = item.second.embedding_size;
180 auto &host_address = item.second.host_address;
181 host_address = reinterpret_cast<float *>(
182 cpu_device_context_->device_res_manager_->AllocateMemory(host_cache_size_ * embedding_size * sizeof(float)));
183 MS_EXCEPTION_IF_NULL(host_address);
184
185 max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
186 }
187
188 device_hash_map_ = std::make_shared<EmbeddingHashMap>(device_cache_size_);
189 MS_EXCEPTION_IF_NULL(device_hash_map_);
190 host_hash_map_ = std::make_shared<EmbeddingHashMap>(host_cache_size_);
191 MS_EXCEPTION_IF_NULL(host_hash_map_);
192
193 hash_swap_index_addr_ = reinterpret_cast<int *>(
194 device_context->device_res_manager_->AllocateMemory(batch_ids_num_ * sizeof(int) * multi_batch_threshold_));
195 MS_EXCEPTION_IF_NULL(hash_swap_index_addr_);
196 hash_swap_value_addr_ = reinterpret_cast<float *>(device_context->device_res_manager_->AllocateMemory(
197 max_embedding_size * batch_ids_num_ * sizeof(float) * multi_batch_threshold_));
198 MS_EXCEPTION_IF_NULL(hash_swap_value_addr_);
199 }
200
GetEmbeddingTableSliceBound()201 void EmbeddingCacheTableManager::GetEmbeddingTableSliceBound() {
202 auto worker_num = ps::PSContext::instance()->worker_num();
203 auto server_num = ps::PSContext::instance()->server_num();
204 if (worker_num == 0) {
205 return;
206 }
207 if (is_sparse_format() && (worker_num > 1 || server_num > 1)) {
208 MS_LOG(EXCEPTION) << "The sparse format can not support multi worker or multi server currently.";
209 }
210
211 uint32_t rank_id = 0;
212 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
213 auto node = distributed::cluster::ClusterContext::instance()->node();
214 MS_EXCEPTION_IF_NULL(node);
215 rank_id = node->rank_id();
216 #endif
217
218 if (!is_sparse_format()) {
219 int local_shard_size = UintToInt(SizeToUint(vocab_size_) / worker_num);
220 if (vocab_size_ % worker_num != 0) {
221 local_shard_size += 1;
222 }
223 local_embedding_slice_bounds_.first = local_shard_size * UintToInt(rank_id);
224 local_embedding_slice_bounds_.second =
225 std::min(local_embedding_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
226 } else {
227 local_embedding_slice_bounds_.first = 0;
228 local_embedding_slice_bounds_.second = INT_MAX;
229 }
230 local_device_cache_bounds_.first = SizeToInt(device_cache_size_) * UintToInt(rank_id);
231 local_device_cache_bounds_.second = local_device_cache_bounds_.first + SizeToInt(device_cache_size_);
232 MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id
233 << ", id begin:" << local_embedding_slice_bounds_.first
234 << ", id end:" << local_embedding_slice_bounds_.second
235 << ", cache indices begin: " << local_device_cache_bounds_.first
236 << ", cache indices end: " << local_device_cache_bounds_.second << ", vocab_size: " << vocab_size_
237 << ", device cache size: " << device_cache_size_;
238 }
239
cache_indices_lower_bound() const240 int EmbeddingCacheTableManager::cache_indices_lower_bound() const { return local_device_cache_bounds_.first; }
241
DumpHashTables() const242 void EmbeddingCacheTableManager::DumpHashTables() const {
243 for (const auto &item : hash_tables_) {
244 const auto ¶m_name = item.first;
245 size_t cache_vocab_size = item.second.cache_vocab_size;
246 size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
247 size_t embedding_size = item.second.embedding_size;
248 size_t vocab_size = item.second.vocab_size;
249 int32_t param_key = item.second.param_key_;
250 MS_LOG(INFO) << "Hash table info:"
251 << " param_key:" << param_key << ", embedding table name:" << param_name
252 << ", vocab size:" << vocab_size << ", embedding size:" << embedding_size
253 << ", device cache size:" << cache_vocab_size << ", host cache size:" << host_cache_vocab_size
254 << ", device cache address:" << item.second.address.addr
255 << ", host cache address:" << item.second.host_address;
256 }
257 }
258
enable_pipeline() const259 bool EmbeddingCacheTableManager::enable_pipeline() const {
260 return ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable();
261 }
262
StoreWarmUpPtr(const int32_t param_key,const tensor::TensorPtr & tensor_ptr)263 int32_t EmbeddingCacheTableManager::StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &tensor_ptr) {
264 return StoreWarmUpPtr(param_key, nullptr, tensor_ptr, nullptr);
265 }
266
StoreWarmUpPtr(const int32_t param_key,const tensor::TensorPtr & key_ptr,const tensor::TensorPtr & value_ptr,const tensor::TensorPtr & status_ptr)267 int32_t EmbeddingCacheTableManager::StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &key_ptr,
268 const tensor::TensorPtr &value_ptr,
269 const tensor::TensorPtr &status_ptr) {
270 MS_LOG(INFO) << "Enter store warm up ptr, param_key : " << param_key << ".";
271 MS_EXCEPTION_IF_NULL(value_ptr);
272 std::unique_lock<std::mutex> lock(host_cache_mutex_);
273 auto ret = host_cache_ptrs_.find(param_key);
274 if (ret != host_cache_ptrs_.end()) {
275 MS_LOG(WARNING) << "Store warm up ptr duplicate, id : " << param_key << ".";
276 }
277 (void)host_cache_ptrs_.try_emplace(param_key, key_ptr, value_ptr, status_ptr);
278 MS_LOG(INFO) << "Exit store warm up ptr, host cache ptrs size : " << host_cache_ptrs_.size()
279 << ", hash tables size : " << hash_tables_.size() << ".";
280 return 0;
281 }
282
FindHashTablesByParamKey(const int param_key)283 const HashTableInfo *EmbeddingCacheTableManager::FindHashTablesByParamKey(const int param_key) {
284 const auto &iter = std::find_if(hash_tables_.begin(), hash_tables_.end(),
285 [this, param_key](const auto &data) { return data.second.param_key_ == param_key; });
286 return iter != hash_tables_.end() ? &(iter->second) : nullptr;
287 }
288
WarmUpHostCacheSync(const int32_t batch_count)289 void EmbeddingCacheTableManager::WarmUpHostCacheSync(const int32_t batch_count) {
290 MS_LOG(INFO) << "Enter warm up host cache sync, batch_count : " << batch_count << ".";
291 auto cache_ptr_size = host_cache_ptrs_.size();
292 auto hash_table_size = hash_tables_.size();
293 if (cache_ptr_size != hash_table_size) {
294 MS_LOG(WARNING) << "Host cache ptrs size : " << cache_ptr_size
295 << " is not equal to hash table size : " << hash_table_size
296 << ", will skip warm up host cache sync.";
297 std::unique_lock<std::mutex> lock(host_cache_mutex_);
298 host_cache_promise_->set_value(false);
299 lock.unlock();
300 host_cache_ptrs_.clear();
301 return;
302 }
303
304 for (auto &item : host_cache_ptrs_) {
305 WarmUpHostCacheItemBatch(batch_count, item);
306 }
307 std::unique_lock<std::mutex> lock(host_cache_mutex_);
308 host_cache_promise_->set_value(true);
309 lock.unlock();
310 host_cache_ptrs_.clear();
311 MS_LOG(INFO) << "Exit warm up host cache sync.";
312 }
313
WarmUpHostCacheAsync(const int32_t batch_count)314 void EmbeddingCacheTableManager::WarmUpHostCacheAsync(const int32_t batch_count) {
315 MS_LOG(DEBUG) << "Enter warm up host cache async, batch_count : " << batch_count << ".";
316 std::unique_lock<std::mutex> lock(host_cache_mutex_);
317 if (host_cache_promise_ != nullptr) {
318 lock.unlock();
319 MS_LOG(WARNING) << "Host cache promise is not null, cache sync has already done.";
320 return;
321 }
322 host_cache_promise_ = std::make_shared<std::promise<bool>>();
323 lock.unlock();
324 std::thread([this, batch_count]() { WarmUpHostCacheSync(batch_count); }).detach();
325 MS_LOG(DEBUG) << "Exit warm up host cache async.";
326 }
327
GetWarmUpHostCacheAsyncStatus()328 std::pair<std::shared_ptr<std::future<bool>>, bool> EmbeddingCacheTableManager::GetWarmUpHostCacheAsyncStatus() {
329 MS_LOG(DEBUG) << "Enter get warm up host cache async status.";
330 std::unique_lock<std::mutex> lock(host_cache_mutex_);
331 if (host_cache_promise_ == nullptr) {
332 return std::make_pair(nullptr, false);
333 }
334 return std::make_pair(std::make_shared<std::future<bool>>(host_cache_promise_->get_future()), true);
335 }
336
WaitForWarmUpHostCacheComplete()337 bool EmbeddingCacheTableManager::WaitForWarmUpHostCacheComplete() {
338 MS_LOG(DEBUG) << "Enter wait for warm up host cache complete.";
339 const int32_t batch_count = 4;
340 WarmUpHostCacheAsync(batch_count);
341 const auto &[complete_future, status] = GetWarmUpHostCacheAsyncStatus();
342 return status ? complete_future->get() : status;
343 }
344
/**
 * @brief Build an int32 key tensor holding 0, 1, ..., N-1, where N is the first dimension of `tensor_ptr`.
 *
 * @param[in] `tensor_ptr`: The tensor whose first dimension defines the number of keys. Must not be null and
 * must have a non-empty shape with a non-negative first dimension, otherwise an exception is raised.
 * @return The newly created key tensor.
 */
tensor::TensorPtr generate_key_tensor_ptr(const tensor::TensorPtr &tensor_ptr) {
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  const auto &shape = tensor_ptr->shape();
  // Reading shape[0] of an empty shape or sizing a vector with a negative dimension is invalid.
  if (shape.empty() || shape[0] < 0) {
    MS_LOG(EXCEPTION) << "Can not generate key tensor, the first dimension of the tensor is missing or negative.";
  }
  const auto key_num = static_cast<int>(shape[0]);
  std::vector<int32_t> key_vec(static_cast<size_t>(key_num));
  for (int i = 0; i < key_num; i++) {
    key_vec[i] = i;
  }
  return std::make_shared<tensor::Tensor>(key_vec);
}
354
WarmUpHostCacheItemBatch(const int32_t batch_count,const WarmUpCacheMapEntry & entry)355 void EmbeddingCacheTableManager::WarmUpHostCacheItemBatch(const int32_t batch_count, const WarmUpCacheMapEntry &entry) {
356 MS_LOG(DEBUG) << "Enter warm up host cache item batch.";
357 auto key_ptr = std::get<0>(entry.second);
358 auto value_ptr = std::get<1>(entry.second);
359 MS_EXCEPTION_IF_NULL(value_ptr);
360 // Key tensor may be nullptr since we stored single value tensor.
361 if (key_ptr == nullptr) {
362 MS_LOG(INFO) << "key_ptr is nullptr, generate key tensor.";
363 key_ptr = generate_key_tensor_ptr(value_ptr);
364 }
365 auto &vec = key_ptr->shape();
366 auto l_len = static_cast<int>(vec[0]);
367 const int32_t default_batch_count = 1;
368 const int validate_batch_count = batch_count < default_batch_count ? default_batch_count : batch_count;
369 int batch_size = l_len / validate_batch_count;
370 if (l_len % validate_batch_count != 0) {
371 batch_size++;
372 }
373
374 if (host_hash_map_ == nullptr) {
375 MS_LOG(WARNING) << "Embedding hash map of embedding host cache is nullptr, will skip warm up.";
376 return;
377 }
378
379 auto hash_table_info_ptr = FindHashTablesByParamKey(entry.first);
380 if (hash_table_info_ptr == nullptr) {
381 MS_LOG(WARNING) << "Hash table info is nullptr, will skip warm up.";
382 return;
383 }
384
385 size_t host_length = (hash_table_info_ptr->host_cache_vocab_size * hash_table_info_ptr->embedding_size) << 2;
386 auto &value_shape = value_ptr->shape();
387 size_t value_len = 0;
388 (void)std::for_each(value_shape.begin() + 1, value_shape.end(), [&](int n) { value_len += n; });
389 MS_EXCEPTION_IF_NULL(value_ptr->data_ptr());
390 value_len *= static_cast<size_t>(value_ptr->data_ptr()->itemsize());
391 size_t value_expected_len = value_len * (value_shape[0] + 1);
392 MS_EXCEPTION_IF_CHECK_FAIL(value_expected_len <= host_length, "Size of value tensor is overflow.");
393
394 for (int i = 0; i < l_len; i += batch_size) {
395 WarmUpHostCacheItem(host_hash_map_, hash_table_info_ptr, entry, i, std::min(i + batch_size, l_len), value_len);
396 }
397 MS_LOG(DEBUG) << "Exit warm up host cache item batch.";
398 }
399
WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> & embedding_hash_map,const HashTableInfo * hash_table_info_ptr,const WarmUpCacheMapEntry & entry,const int start,const int end,const size_t value_len)400 void EmbeddingCacheTableManager::WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> &embedding_hash_map,
401 const HashTableInfo *hash_table_info_ptr,
402 const WarmUpCacheMapEntry &entry, const int start, const int end,
403 const size_t value_len) {
404 // Value type is float, bit num is 2
405 const int shift_bit_num = 2;
406 MS_EXCEPTION_IF_NULL(hash_table_info_ptr);
407 if (hash_table_info_ptr->embedding_size != (value_len >> shift_bit_num)) {
408 MS_LOG(WARNING) << "Hash table info embedding_size : " << hash_table_info_ptr->embedding_size
409 << " is not equal to value_len : " << value_len << ".";
410 return;
411 }
412
413 auto key_ptr = std::get<0>(entry.second);
414 MS_EXCEPTION_IF_NULL(key_ptr);
415 auto key_data_ptr = key_ptr->data_ptr();
416 for (ssize_t i = start; i != end; i++) {
417 auto key_data_type = key_ptr->data_type();
418 int64_t key = 0;
419 switch (key_data_type) {
420 case TypeId::kNumberTypeInt32:
421 case TypeId::kNumberTypeUInt32: {
422 auto int_ptr = static_cast<int *>(key_ptr->data_c());
423 key = *(int_ptr + i);
424 } break;
425 case TypeId::kNumberTypeInt64:
426 case TypeId::kNumberTypeUInt64: {
427 auto int64_ptr = static_cast<int64_t *>(key_ptr->data_c());
428 key = *(int64_ptr + i);
429 } break;
430 default:
431 MS_LOG(WARNING) << "Invalid key_data_type : " << key_data_type << ".";
432 return;
433 }
434
435 int id = embedding_hash_map->GetOrInsertDataUnsafe(static_cast<int>(key));
436 if (id == kInvalidIndexValue) {
437 MS_LOG(WARNING) << "Embedding hash map is full, exit warm up process.";
438 break;
439 }
440
441 size_t offset = static_cast<size_t>(id) * value_len;
442 auto host_address = hash_table_info_ptr->host_address;
443 auto des_ptr = AddressOffset(host_address, 0);
444 auto value_data_ptr = std::get<1>(entry.second)->data_c();
445 auto src_ptr = AddressOffset(value_data_ptr, 0);
446 auto ret_code = memcpy_s(des_ptr + offset, value_len, src_ptr + i * value_len, value_len);
447 if (ret_code != EOK) {
448 MS_LOG(EXCEPTION) << "Failed to copy data, memcpy_s errorno: " << ret_code;
449 }
450 }
451 }
452
GetInstance()453 EmbeddingStorageManager &EmbeddingStorageManager::GetInstance() {
454 static EmbeddingStorageManager instance{};
455 return instance;
456 }
457
458 namespace {
459 /**
460 * @brief Create a new embedding storage instance for specific key and value type, and add the instance to
461 * EmbeddingStorageManager, this function is the implementation of function `CreateEmbeddingStorage`.
462 * @param[in] `embedding_key`: The unique parameter key for embedding table.
463 * @param[in] `embedding_dim`: The length of each embedding vector.
464 * @param[in] `capacity`: The capacity for new embedding storage.
465 */
466 template <typename KeyType, typename ValueType>
CreateEmbeddingStorageFunc(int32_t embedding_key,size_t embedding_dim,size_t capacity)467 void CreateEmbeddingStorageFunc(int32_t embedding_key, size_t embedding_dim, size_t capacity) {
468 std::shared_ptr<storage::AbstractEmbeddingStorage> embedding_storage = nullptr;
469 if (!EmbeddingCacheTableManager::GetInstance().is_sparse_format()) {
470 embedding_storage =
471 std::make_shared<storage::DenseEmbeddingStorage<KeyType, ValueType>>(embedding_key, embedding_dim, capacity);
472 } else {
473 embedding_storage =
474 std::make_shared<storage::SparseEmbeddingStorage<KeyType, ValueType>>(embedding_key, embedding_dim, capacity);
475 }
476 MS_EXCEPTION_IF_NULL(embedding_storage);
477 EmbeddingStorageManager::GetInstance().Add(embedding_key, embedding_storage);
478 }
479
// Key-Value type pair -> CreateEmbeddingStorageFunc map.
// Dispatch table used by `CreateEmbeddingStorage`: each (key TypeId, value TypeId) pair is mapped to the
// matching instantiation of `CreateEmbeddingStorageFunc`. Only int32 and int64 key types are registered, each
// combined with the bool, integer and floating point value types listed below.
const std::map<std::pair<TypeId, TypeId>, std::function<void(int32_t, size_t, size_t)>> kCreateEmbeddingStorageFuncs = {
  // Entries for int32 keys.
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeBool), CreateEmbeddingStorageFunc<int32_t, bool>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt8), CreateEmbeddingStorageFunc<int32_t, int8_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16), CreateEmbeddingStorageFunc<int32_t, int16_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32), CreateEmbeddingStorageFunc<int32_t, int32_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt64), CreateEmbeddingStorageFunc<int32_t, int64_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt8), CreateEmbeddingStorageFunc<int32_t, uint8_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt16), CreateEmbeddingStorageFunc<int32_t, uint16_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt32), CreateEmbeddingStorageFunc<int32_t, uint32_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeUInt64), CreateEmbeddingStorageFunc<int32_t, uint64_t>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat16), CreateEmbeddingStorageFunc<int32_t, float16>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32), CreateEmbeddingStorageFunc<int32_t, float>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat64), CreateEmbeddingStorageFunc<int32_t, double>},
  {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeBFloat16),
   CreateEmbeddingStorageFunc<int32_t, bfloat16>},

  // Entries for int64 keys.
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeBool), CreateEmbeddingStorageFunc<int64_t, bool>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt8), CreateEmbeddingStorageFunc<int64_t, int8_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt16), CreateEmbeddingStorageFunc<int64_t, int16_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt32), CreateEmbeddingStorageFunc<int64_t, int32_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64), CreateEmbeddingStorageFunc<int64_t, int64_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt8), CreateEmbeddingStorageFunc<int64_t, uint8_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt16), CreateEmbeddingStorageFunc<int64_t, uint16_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt32), CreateEmbeddingStorageFunc<int64_t, uint32_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeUInt64), CreateEmbeddingStorageFunc<int64_t, uint64_t>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat16), CreateEmbeddingStorageFunc<int64_t, float16>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32), CreateEmbeddingStorageFunc<int64_t, float>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat64), CreateEmbeddingStorageFunc<int64_t, double>},
  {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeBFloat16),
   CreateEmbeddingStorageFunc<int64_t, bfloat16>}};
511 } // namespace
512
CreateEmbeddingStorage(std::pair<TypeId,TypeId> key_value_types,int32_t embedding_key,size_t embedding_dim,size_t capacity)513 void CreateEmbeddingStorage(std::pair<TypeId, TypeId> key_value_types, int32_t embedding_key, size_t embedding_dim,
514 size_t capacity) {
515 const auto &iter = kCreateEmbeddingStorageFuncs.find(key_value_types);
516 if (iter == kCreateEmbeddingStorageFuncs.end()) {
517 MS_LOG(EXCEPTION) << "Can not find function to create embedding storage for key type:"
518 << TypeIdToString(key_value_types.first)
519 << ", value type:" << TypeIdToString(key_value_types.second);
520 }
521 iter->second(embedding_key, embedding_dim, capacity);
522 }
523
EmbeddingDeviceCache(size_t batch_ids_num)524 EmbeddingDeviceCache::EmbeddingDeviceCache(size_t batch_ids_num) {
525 device_to_host_index = std::make_unique<int[]>(batch_ids_num);
526 device_to_host_ids = std::make_unique<int[]>(batch_ids_num);
527 host_to_device_index = std::make_unique<int[]>(batch_ids_num);
528 host_to_device_ids = std::make_unique<int[]>(batch_ids_num);
529 }
530
EmbeddingHostCache(size_t batch_ids_num)531 EmbeddingHostCache::EmbeddingHostCache(size_t batch_ids_num) {
532 host_to_server_index = std::make_unique<int[]>(batch_ids_num);
533 host_to_server_ids = std::make_unique<int[]>(batch_ids_num);
534 server_to_host_index = std::make_unique<int[]>(batch_ids_num);
535 server_to_host_ids = std::make_unique<int[]>(batch_ids_num);
536 new_id_index = std::make_unique<int[]>(batch_ids_num);
537 host_to_device_index = std::make_unique<int[]>(batch_ids_num);
538 device_to_host_index = std::make_unique<int[]>(batch_ids_num);
539 }
540
// Registers `embed_storage` under `param_key`; a storage already registered for that key is replaced.
// Raises an exception when `embed_storage` is null.
void EmbeddingStorageManager::Add(int32_t param_key,
                                  const std::shared_ptr<storage::AbstractEmbeddingStorage> &embed_storage) {
  MS_EXCEPTION_IF_NULL(embed_storage);
  embedding_storages_[param_key] = embed_storage;
}
546
Get(int32_t param_key)547 std::shared_ptr<storage::AbstractEmbeddingStorage> EmbeddingStorageManager::Get(int32_t param_key) {
548 const auto &iter = embedding_storages_.find(param_key);
549 if (iter != embedding_storages_.end()) {
550 return iter->second;
551 }
552 MS_LOG(EXCEPTION) << "Can not find embedding storage for parameter key[" << param_key << "].";
553 }
554
Clear()555 void EmbeddingStorageManager::Clear() {
556 for (const auto &item : embedding_storages_) {
557 const auto &embedding_storage = item.second;
558 MS_EXCEPTION_IF_NULL(embedding_storage);
559 embedding_storage->Finalize();
560 }
561
562 embedding_storages_.clear();
563 }
564 } // namespace distributed
565 } // namespace mindspore
566