• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "distributed/embedding_cache/embedding_storage/sparse_embedding_storage.h"
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21 
22 namespace mindspore {
23 namespace distributed {
24 namespace storage {
// Number of bytes in one megabyte (2^20); converts a size given in megabytes to bytes.
constexpr size_t kMegaByteToByteRate = size_t{1} << 20;
26 
27 template <typename KeyType, typename ValueType, typename Allocator>
Initialize(const DeviceAddress * device_address)28 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Initialize(const DeviceAddress *device_address) {
29   MS_EXCEPTION_IF_NULL(device_address);
30   EmbeddingStorage<KeyType, ValueType, Allocator>::Initialize(device_address);
31 
32   auto user_data = device_address->user_data();
33   MS_EXCEPTION_IF_NULL(user_data);
34   hash_table_ = user_data->get<HashTable>(kUserDataData).get();
35   MS_EXCEPTION_IF_NULL(hash_table_);
36 }
37 
38 template <typename KeyType, typename ValueType, typename Allocator>
Finalize()39 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Finalize() {
40   hash_table_ = nullptr;
41   EmbeddingStorage<KeyType, ValueType, Allocator>::Finalize();
42 }
43 
44 template <typename KeyType, typename ValueType, typename Allocator>
Get(const ConstDataWithLen & keys,const DataWithLen & values)45 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Get(const ConstDataWithLen &keys,
46                                                                 const DataWithLen &values) {
47   const KeyType *keys_data = reinterpret_cast<const KeyType *>(keys.data_);
48   ValueType *values_data = reinterpret_cast<ValueType *>(values.data_);
49   size_t key_num = keys.data_len_ / sizeof(KeyType);
50   if (values.data_len_ < key_num * this->embedding_dim_ * sizeof(ValueType)) {
51     MS_LOG(EXCEPTION) << "The value buffer length is insufficient.";
52   }
53   MS_EXCEPTION_IF_NULL(keys_data);
54   MS_EXCEPTION_IF_NULL(values_data);
55   MS_EXCEPTION_IF_NULL(hash_table_);
56 
57   // 1. Query cache to analyse the information of cache hit and miss keys, update the positions of cache hit elements in
58   // the cache (cache refresh).
59   size_t cache_miss_cnt = 0;
60   size_t *cache_miss_offsets = this->template AllocateMemory<size_t>(sizeof(size_t) * key_num);
61   MS_EXCEPTION_IF_NULL(cache_miss_offsets);
62   bool *cache_hit = this->template AllocateMemory<bool>(sizeof(bool) * key_num);
63   MS_EXCEPTION_IF_NULL(cache_hit);
64   QueryCache(keys_data, key_num, cache_miss_offsets, &cache_miss_cnt, cache_hit);
65 
66   // 2. Copy the embeddings from cache to the returned values for cache hit keys.
67   for (size_t i = 0; i < key_num; i++) {
68     if (!cache_hit[i]) {
69       continue;
70     }
71     RETURN_IF_FALSE_WITH_LOG(
72       hash_table_->Find(keys_data + i, 1, false, values_data + this->embedding_dim_ * i, nullptr),
73       "Find key from hash table failed.");
74   }
75 
76   if (cache_miss_cnt == 0) {
77     return true;
78   }
79 
80   // 3. Reserve space for cache miss keys in the cache (if there is enough space in the cache, then do nothing), write
81   // the evicted element to persistent storage, and record the space in the cache, using the space in the cache first.
82   RETURN_IF_FALSE_WITH_LOG(TryEvict(cache_miss_cnt), "Reserve space for miss keys failed.");
83 
84   // 4. Insert the cache miss elements into the cache from persistent storage, and copy them to the returned values.
85   RETURN_IF_FALSE_WITH_LOG(InsertMissCacheFromStorage(keys_data, cache_miss_offsets, cache_miss_cnt, values_data),
86                            "Insert the cache miss elements into the cache from persistent storage failed.");
87 
88   this->FreeMemory(cache_hit);
89   this->FreeMemory(cache_miss_offsets);
90   return true;
91 }
92 
93 template <typename KeyType, typename ValueType, typename Allocator>
Put(const ConstDataWithLen & keys,const ConstDataWithLen & values)94 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Put(const ConstDataWithLen &keys,
95                                                                 const ConstDataWithLen &values) {
96   const KeyType *keys_data = reinterpret_cast<const KeyType *>(keys.data_);
97   const ValueType *values_data = reinterpret_cast<const ValueType *>(values.data_);
98   size_t key_num = keys.data_len_ / sizeof(KeyType);
99   if (values.data_len_ != key_num * this->embedding_dim_ * sizeof(ValueType)) {
100     MS_LOG(EXCEPTION) << "The value length is invalid, expected length["
101                       << key_num * this->embedding_dim_ * sizeof(ValueType) << "], but got[" << values.data_len_ << "]";
102   }
103   MS_EXCEPTION_IF_NULL(keys_data);
104   MS_EXCEPTION_IF_NULL(values_data);
105   MS_EXCEPTION_IF_NULL(hash_table_);
106 
107   // 1. Query cache to analyse the information of cache hit and miss keys, update the positions of cache hit elements in
108   // the cache (cache refresh).
109   size_t cache_miss_cnt = 0;
110   size_t *cache_miss_offsets = this->template AllocateMemory<size_t>(sizeof(size_t) * key_num);
111   MS_EXCEPTION_IF_NULL(cache_miss_offsets);
112   bool *cache_hit = this->template AllocateMemory<bool>(sizeof(bool) * key_num);
113   MS_EXCEPTION_IF_NULL(cache_hit);
114   QueryCache(keys_data, key_num, cache_miss_offsets, &cache_miss_cnt, cache_hit);
115 
116   // 2. Update the embedding value to the cache for cache hit keys.
117   for (size_t i = 0; i < key_num; i++) {
118     if (!cache_hit[i]) {
119       continue;
120     }
121     RETURN_IF_FALSE_WITH_LOG(hash_table_->Insert(keys_data + i, 1, values_data + this->embedding_dim_ * i, nullptr),
122                              "Insert hash table failed.");
123   }
124 
125   if (cache_miss_cnt == 0) {
126     return true;
127   }
128 
129   // 3. Reserve space for cache miss keys in the cache (if there is enough space in the cache, then do nothing), write
130   // the evicted element to persistent storage, and record the space in the cache, using the space in the cache first.
131   RETURN_IF_FALSE_WITH_LOG(TryEvict(cache_miss_cnt), "Reserve space for miss keys failed.");
132 
133   // 4. Insert the cache miss elements into the cache from host memory.
134   // Note: step 2 and step 4 can not merge.
135   RETURN_IF_FALSE_WITH_LOG(InsertMissCacheFromMemory(keys_data, cache_miss_offsets, cache_miss_cnt, values_data),
136                            "Insert cache miss elements into cache from host memory failed.");
137 
138   this->FreeMemory(cache_hit);
139   this->FreeMemory(cache_miss_offsets);
140   return true;
141 }
142 
143 template <typename KeyType, typename ValueType, typename Allocator>
QueryCache(const KeyType * keys,size_t key_num,size_t * cache_miss_offsets,size_t * cache_miss_cnt,bool * cache_hit) const144 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::QueryCache(const KeyType *keys, size_t key_num,
145                                                                        size_t *cache_miss_offsets,
146                                                                        size_t *cache_miss_cnt, bool *cache_hit) const {
147   MS_EXCEPTION_IF_NULL(keys);
148   MS_EXCEPTION_IF_NULL(cache_miss_offsets);
149   MS_EXCEPTION_IF_NULL(cache_miss_cnt);
150   MS_EXCEPTION_IF_NULL(cache_hit);
151   MS_EXCEPTION_IF_NULL(this->cache_);
152 
153   int fake_index;
154   for (size_t i = 0; i < key_num; i++) {
155     if (this->cache_->Get(keys[i], &fake_index)) {
156       // Touch keys to affect the location or order of the elements in the cache, the return value for hash table is
157       // useless.
158       cache_hit[i] = true;
159       continue;
160     }
161 
162     cache_hit[i] = false;
163     // Record cache miss key's offset in all query keys.
164     cache_miss_offsets[(*cache_miss_cnt)++] = i;
165   }
166 
167   MS_LOG(DEBUG) << "Total keys number: " << key_num << ", cache hit number: " << (key_num - *cache_miss_cnt)
168                 << ", cache hit rate: " << static_cast<float>(key_num - *cache_miss_cnt) / static_cast<float>(key_num);
169 }
170 
171 template <typename KeyType, typename ValueType, typename Allocator>
TryEvict(size_t reserve_size)172 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::TryEvict(size_t reserve_size) {
173   // 1. Try evict some non-hot data in cache to reserve space for elements that will be inserted into the cache.
174   if (reserve_size > this->cache_capacity_) {
175     MS_LOG(EXCEPTION) << "The evict number must be less or equal to cache capacity: " << this->cache_capacity_
176                       << ", but got: " << reserve_size << ", please enlarge cache capacity";
177   }
178 
179   MS_EXCEPTION_IF_NULL(this->cache_);
180   std::vector<CacheElement> evicted_elements;
181   this->cache_->TryEvict(reserve_size, &evicted_elements);
182   if (evicted_elements.size() == 0) {
183     return true;
184   }
185 
186   size_t evicted_keys_len = evicted_elements.size() * sizeof(KeyType);
187   KeyType *evicted_keys = this->template AllocateMemory<KeyType>(evicted_keys_len);
188   MS_EXCEPTION_IF_NULL(evicted_keys);
189 
190   size_t evicted_cnt = 0;
191   (void)std::for_each(evicted_elements.begin(), evicted_elements.end(),
192                       [&, this](const CacheElement &element) { evicted_keys[evicted_cnt++] = element.first; });
193 
194   // 2. Get all evicted embedding vector values.
195   size_t evicted_values_len = evicted_elements.size() * this->embedding_dim_ * sizeof(ValueType);
196   ValueType *evicted_values = this->template AllocateMemory<ValueType>(evicted_values_len);
197   MS_EXCEPTION_IF_NULL(evicted_values);
198   MS_EXCEPTION_IF_NULL(hash_table_);
199   for (size_t i = 0; i < evicted_elements.size(); i++) {
200     RETURN_IF_FALSE_WITH_LOG(
201       hash_table_->Find(evicted_keys + i, 1, false, evicted_values + this->embedding_dim_ * i, nullptr),
202       "Find key from hash table failed.");
203     // Erase evicted element from hash table after using.
204     RETURN_IF_FALSE_WITH_LOG(hash_table_->Erase(evicted_keys + i, 1, nullptr), "Erase key from hash table failed.");
205   }
206 
207   if (this->cache_->size() != hash_table_->size()) {
208     MS_LOG(EXCEPTION) << "The size of cache and hash table should be equal, but got cache size[" << this->cache_->size()
209                       << "], hash table size[" << hash_table_->size() << "].";
210   }
211 
212   // 3. Write evicted elements to persistent storage.
213   MS_EXCEPTION_IF_NULL(this->storage_);
214   this->storage_->Write({evicted_keys, evicted_keys_len}, {evicted_values, evicted_values_len});
215 
216   this->FreeMemory(evicted_keys);
217   this->FreeMemory(evicted_values);
218 
219   return true;
220 }
221 
222 template <typename KeyType, typename ValueType, typename Allocator>
InsertMissCacheFromStorage(const KeyType * keys,const size_t * cache_miss_offsets,size_t cache_miss_cnt,ValueType * values)223 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::InsertMissCacheFromStorage(const KeyType *keys,
224                                                                                        const size_t *cache_miss_offsets,
225                                                                                        size_t cache_miss_cnt,
226                                                                                        ValueType *values) {
227   MS_EXCEPTION_IF_NULL(keys);
228   MS_EXCEPTION_IF_NULL(cache_miss_offsets);
229   MS_EXCEPTION_IF_NULL(values);
230   MS_EXCEPTION_IF_NULL(this->cache_);
231   MS_EXCEPTION_IF_NULL(hash_table_);
232 
233   // 1. Read the cache miss element from the persistent storage.
234   size_t cache_miss_keys_len = cache_miss_cnt * sizeof(KeyType);
235   KeyType *cache_miss_keys = this->template AllocateMemory<KeyType>(cache_miss_keys_len);
236   MS_EXCEPTION_IF_NULL(cache_miss_keys);
237   for (size_t i = 0; i < cache_miss_cnt; i++) {
238     cache_miss_keys[i] = keys[cache_miss_offsets[i]];
239   }
240   size_t cache_miss_values_len = cache_miss_cnt * this->embedding_dim_ * sizeof(ValueType);
241   ValueType *cache_miss_values = this->template AllocateMemory<ValueType>(cache_miss_values_len);
242   MS_EXCEPTION_IF_NULL(cache_miss_values);
243 
244   // Read the miss values from persistent storage.
245   MS_EXCEPTION_IF_NULL(this->storage_);
246   this->storage_->Read({cache_miss_keys, cache_miss_keys_len}, {cache_miss_values, cache_miss_values_len});
247 
248   // 2. Insert the cache miss elements into cache, and copy them to the returned values.
249   for (size_t i = 0; i < cache_miss_cnt; i++) {
250     // Insert key-index pairs of the cache miss elements into the cache, the index for hash embedding table is useless,
251     // set the value to 0.
252     this->cache_->Put(cache_miss_keys[i], 0);
253 
254     // Insert the embedding vectors of cache miss elements to the cache.
255     RETURN_IF_FALSE_WITH_LOG(
256       hash_table_->Insert(cache_miss_keys + i, 1, cache_miss_values + this->embedding_dim_ * i, nullptr),
257       "Insert hash table failed.");
258 
259     // Copy the embedding vectors of cache miss elements to the returned values.
260     auto ret = memcpy_s(values + this->embedding_dim_ * cache_miss_offsets[i], this->embedding_dim_ * sizeof(ValueType),
261                         cache_miss_values + this->embedding_dim_ * i, this->embedding_dim_ * sizeof(ValueType));
262     if (ret != EOK) {
263       MS_LOG(ERROR) << "Memcpy the embedding vectors of cache miss elements to the returned values failed, errno["
264                     << ret << "]";
265       return false;
266     }
267   }
268 
269   this->FreeMemory(cache_miss_keys);
270   this->FreeMemory(cache_miss_values);
271   return true;
272 }
273 
274 template <typename KeyType, typename ValueType, typename Allocator>
InsertMissCacheFromMemory(const KeyType * keys,const size_t * cache_miss_offsets,size_t cache_miss_cnt,const ValueType * values)275 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::InsertMissCacheFromMemory(const KeyType *keys,
276                                                                                       const size_t *cache_miss_offsets,
277                                                                                       size_t cache_miss_cnt,
278                                                                                       const ValueType *values) {
279   MS_EXCEPTION_IF_NULL(keys);
280   MS_EXCEPTION_IF_NULL(cache_miss_offsets);
281   MS_EXCEPTION_IF_NULL(values);
282   MS_EXCEPTION_IF_NULL(this->cache_);
283   MS_EXCEPTION_IF_NULL(hash_table_);
284 
285   for (size_t i = 0; i < cache_miss_cnt; i++) {
286     // Insert key-index pairs of the cache miss elements into the cache, the index for hash embedding table is useless,
287     // set the value to 0.
288     this->cache_->Put(keys[cache_miss_offsets[i]], 0);
289 
290     // Insert the embedding vectors of cache miss elements to the cache.
291     RETURN_IF_FALSE_WITH_LOG(hash_table_->Insert(keys + cache_miss_offsets[i], 1,
292                                                  values + this->embedding_dim_ * cache_miss_offsets[i], nullptr),
293                              "Insert hash table failed.");
294   }
295 
296   return true;
297 }
298 
299 template <typename KeyType, typename ValueType, typename Allocator>
ExportSlice(bool,bool * last_slice,size_t slice_size_in_mega_bytes)300 std::vector<std::shared_ptr<std::vector<char>>> SparseEmbeddingStorage<KeyType, ValueType, Allocator>::ExportSlice(
301   bool, bool *last_slice, size_t slice_size_in_mega_bytes) {
302   MS_EXCEPTION_IF_NULL(last_slice);
303   // Only support fully export currently.
304   // 1. Export data in host cache.
305   if (!this->finish_export_element_in_host_mem_) {
306     return hash_table_->ExportSlice(false, &this->finish_export_element_in_host_mem_, slice_size_in_mega_bytes);
307   }
308 
309   // 2. Export data in storage.
310   static KeyType *deduplicated_keys_in_storage = nullptr;
311   static size_t deduplicated_keys_num;
312   if (this->keys_in_storage_ == nullptr) {
313     this->keys_in_storage_ = this->storage_->GetAllKeys();
314     MS_EXCEPTION_IF_NULL(this->keys_in_storage_);
315     deduplicated_keys_num = LongToSize(std::count_if(this->keys_in_storage_->begin(), this->keys_in_storage_->end(),
316                                                      [this](KeyType key) { return !this->cache_->Exists(key); }));
317     if (deduplicated_keys_num == 0) {
318       *last_slice = true;
319       this->keys_in_storage_ = nullptr;
320       return {std::make_shared<std::vector<char>>(0), std::make_shared<std::vector<char>>(0),
321               std::make_shared<std::vector<char>>(0)};
322     }
323 
324     deduplicated_keys_in_storage = this->template AllocateMemory<KeyType>(deduplicated_keys_num * sizeof(KeyType));
325     MS_EXCEPTION_IF_NULL(deduplicated_keys_in_storage);
326     size_t index = 0;
327     for (size_t i = 0; i < this->keys_in_storage_->size(); ++i) {
328       if (!this->cache_->Exists(this->keys_in_storage_->at(i))) {
329         deduplicated_keys_in_storage[index] = this->keys_in_storage_->at(i);
330         ++index;
331       }
332     }
333   }
334 
335   size_t slice_size = slice_size_in_mega_bytes * kMegaByteToByteRate / (this->embedding_dim_ * sizeof(ValueType));
336   if (slice_size == 0) {
337     MS_LOG(EXCEPTION) << "The parameter[slice_size_in_mega_bytes] " << slice_size_in_mega_bytes
338                       << " should be greater than the length in meta bytes of one element in storage: "
339                       << (this->embedding_dim_ * sizeof(ValueType)) / kMegaByteToByteRate;
340   }
341   if (this->end_ == 0) {
342     this->end_ = std::min(this->begin_ + slice_size, deduplicated_keys_num);
343   }
344 
345   auto ret = ReadSliceFromStorage(deduplicated_keys_in_storage);
346   *last_slice = (this->end_ == deduplicated_keys_num);
347   // Update the iterator and record status.
348   UpdateExportStatus(*last_slice, slice_size, deduplicated_keys_num);
349   if (*last_slice) {
350     this->FreeMemory(deduplicated_keys_in_storage);
351   }
352 
353   return ret;
354 }
355 
356 template <typename KeyType, typename ValueType, typename Allocator>
357 std::vector<std::shared_ptr<std::vector<char>>>
ReadSliceFromStorage(KeyType * keys_in_storage) const358 SparseEmbeddingStorage<KeyType, ValueType, Allocator>::ReadSliceFromStorage(KeyType *keys_in_storage) const {
359   MS_EXCEPTION_IF_NULL(keys_in_storage);
360 
361   if (this->end_ < this->begin_) {
362     MS_LOG(EXCEPTION) << "Invalid export position parameter, begin: " << this->begin_ << ", end: " << this->end_;
363   }
364   const size_t size = this->end_ - this->begin_;
365   auto keys = std::make_shared<std::vector<char>>(size * sizeof(KeyType));
366   auto keys_data = reinterpret_cast<KeyType *>(keys->data());
367   auto values = std::make_shared<std::vector<char>>(size * this->embedding_dim_ * sizeof(ValueType));
368   auto values_data = reinterpret_cast<ValueType *>(values->data());
369   auto statuses = std::make_shared<std::vector<char>>(size * sizeof(HashTableElementStatus));
370 
371   auto ret = memcpy_s(keys_data, size * sizeof(KeyType), keys_in_storage + this->begin_, size * sizeof(KeyType));
372   if (ret != EOK) {
373     MS_LOG(EXCEPTION) << "Memcpy failed, errno[" << ret << "].";
374   }
375 
376   // Read from storage
377   this->storage_->Read({keys_data, size * sizeof(KeyType)},
378                        {values_data, size * this->embedding_dim_ * sizeof(ValueType)});
379 
380   return {keys, values, statuses};
381 }
382 
383 template <typename KeyType, typename ValueType, typename Allocator>
UpdateExportStatus(bool last_slice,size_t slice_size,size_t deduplicated_keys_num_in_storage)384 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::UpdateExportStatus(
385   bool last_slice, size_t slice_size, size_t deduplicated_keys_num_in_storage) {
386   if (last_slice) {
387     this->begin_ = 0;
388     this->end_ = 0;
389     this->keys_in_storage_ = nullptr;
390     this->finish_export_element_in_host_mem_ = false;
391   } else {
392     this->begin_ += slice_size;
393     this->end_ = std::min(this->begin_ + slice_size, deduplicated_keys_num_in_storage);
394   }
395 }
396 
397 template class SparseEmbeddingStorage<int32_t, bool>;
398 template class SparseEmbeddingStorage<int32_t, int8_t>;
399 template class SparseEmbeddingStorage<int32_t, int16_t>;
400 template class SparseEmbeddingStorage<int32_t, int32_t>;
401 template class SparseEmbeddingStorage<int32_t, int64_t>;
402 template class SparseEmbeddingStorage<int32_t, uint8_t>;
403 template class SparseEmbeddingStorage<int32_t, uint16_t>;
404 template class SparseEmbeddingStorage<int32_t, uint32_t>;
405 template class SparseEmbeddingStorage<int32_t, uint64_t>;
406 template class SparseEmbeddingStorage<int32_t, float16>;
407 template class SparseEmbeddingStorage<int32_t, float>;
408 template class SparseEmbeddingStorage<int32_t, double>;
409 template class SparseEmbeddingStorage<int32_t, bfloat16>;
410 
411 template class SparseEmbeddingStorage<int64_t, bool>;
412 template class SparseEmbeddingStorage<int64_t, int8_t>;
413 template class SparseEmbeddingStorage<int64_t, int16_t>;
414 template class SparseEmbeddingStorage<int64_t, int32_t>;
415 template class SparseEmbeddingStorage<int64_t, int64_t>;
416 template class SparseEmbeddingStorage<int64_t, uint8_t>;
417 template class SparseEmbeddingStorage<int64_t, uint16_t>;
418 template class SparseEmbeddingStorage<int64_t, uint32_t>;
419 template class SparseEmbeddingStorage<int64_t, uint64_t>;
420 template class SparseEmbeddingStorage<int64_t, float16>;
421 template class SparseEmbeddingStorage<int64_t, float>;
422 template class SparseEmbeddingStorage<int64_t, double>;
423 template class SparseEmbeddingStorage<int64_t, bfloat16>;
424 
425 template class SparseEmbeddingStorage<int32_t, float, std::allocator<uint8_t>>;
426 template class SparseEmbeddingStorage<int64_t, float, std::allocator<uint8_t>>;
427 }  // namespace storage
428 }  // namespace distributed
429 }  // namespace mindspore
430