1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "distributed/embedding_cache/embedding_storage/sparse_embedding_storage.h"
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21
22 namespace mindspore {
23 namespace distributed {
24 namespace storage {
// Number of bytes in one megabyte (2^20); used to convert a slice size given in megabytes into bytes.
constexpr size_t kMegaByteToByteRate = static_cast<size_t>(1024) * static_cast<size_t>(1024);
26
27 template <typename KeyType, typename ValueType, typename Allocator>
Initialize(const DeviceAddress * device_address)28 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Initialize(const DeviceAddress *device_address) {
29 MS_EXCEPTION_IF_NULL(device_address);
30 EmbeddingStorage<KeyType, ValueType, Allocator>::Initialize(device_address);
31
32 auto user_data = device_address->user_data();
33 MS_EXCEPTION_IF_NULL(user_data);
34 hash_table_ = user_data->get<HashTable>(kUserDataData).get();
35 MS_EXCEPTION_IF_NULL(hash_table_);
36 }
37
38 template <typename KeyType, typename ValueType, typename Allocator>
Finalize()39 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Finalize() {
40 hash_table_ = nullptr;
41 EmbeddingStorage<KeyType, ValueType, Allocator>::Finalize();
42 }
43
44 template <typename KeyType, typename ValueType, typename Allocator>
Get(const ConstDataWithLen & keys,const DataWithLen & values)45 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Get(const ConstDataWithLen &keys,
46 const DataWithLen &values) {
47 const KeyType *keys_data = reinterpret_cast<const KeyType *>(keys.data_);
48 ValueType *values_data = reinterpret_cast<ValueType *>(values.data_);
49 size_t key_num = keys.data_len_ / sizeof(KeyType);
50 if (values.data_len_ < key_num * this->embedding_dim_ * sizeof(ValueType)) {
51 MS_LOG(EXCEPTION) << "The value buffer length is insufficient.";
52 }
53 MS_EXCEPTION_IF_NULL(keys_data);
54 MS_EXCEPTION_IF_NULL(values_data);
55 MS_EXCEPTION_IF_NULL(hash_table_);
56
57 // 1. Query cache to analyse the information of cache hit and miss keys, update the positions of cache hit elements in
58 // the cache (cache refresh).
59 size_t cache_miss_cnt = 0;
60 size_t *cache_miss_offsets = this->template AllocateMemory<size_t>(sizeof(size_t) * key_num);
61 MS_EXCEPTION_IF_NULL(cache_miss_offsets);
62 bool *cache_hit = this->template AllocateMemory<bool>(sizeof(bool) * key_num);
63 MS_EXCEPTION_IF_NULL(cache_hit);
64 QueryCache(keys_data, key_num, cache_miss_offsets, &cache_miss_cnt, cache_hit);
65
66 // 2. Copy the embeddings from cache to the returned values for cache hit keys.
67 for (size_t i = 0; i < key_num; i++) {
68 if (!cache_hit[i]) {
69 continue;
70 }
71 RETURN_IF_FALSE_WITH_LOG(
72 hash_table_->Find(keys_data + i, 1, false, values_data + this->embedding_dim_ * i, nullptr),
73 "Find key from hash table failed.");
74 }
75
76 if (cache_miss_cnt == 0) {
77 return true;
78 }
79
80 // 3. Reserve space for cache miss keys in the cache (if there is enough space in the cache, then do nothing), write
81 // the evicted element to persistent storage, and record the space in the cache, using the space in the cache first.
82 RETURN_IF_FALSE_WITH_LOG(TryEvict(cache_miss_cnt), "Reserve space for miss keys failed.");
83
84 // 4. Insert the cache miss elements into the cache from persistent storage, and copy them to the returned values.
85 RETURN_IF_FALSE_WITH_LOG(InsertMissCacheFromStorage(keys_data, cache_miss_offsets, cache_miss_cnt, values_data),
86 "Insert the cache miss elements into the cache from persistent storage failed.");
87
88 this->FreeMemory(cache_hit);
89 this->FreeMemory(cache_miss_offsets);
90 return true;
91 }
92
93 template <typename KeyType, typename ValueType, typename Allocator>
Put(const ConstDataWithLen & keys,const ConstDataWithLen & values)94 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::Put(const ConstDataWithLen &keys,
95 const ConstDataWithLen &values) {
96 const KeyType *keys_data = reinterpret_cast<const KeyType *>(keys.data_);
97 const ValueType *values_data = reinterpret_cast<const ValueType *>(values.data_);
98 size_t key_num = keys.data_len_ / sizeof(KeyType);
99 if (values.data_len_ != key_num * this->embedding_dim_ * sizeof(ValueType)) {
100 MS_LOG(EXCEPTION) << "The value length is invalid, expected length["
101 << key_num * this->embedding_dim_ * sizeof(ValueType) << "], but got[" << values.data_len_ << "]";
102 }
103 MS_EXCEPTION_IF_NULL(keys_data);
104 MS_EXCEPTION_IF_NULL(values_data);
105 MS_EXCEPTION_IF_NULL(hash_table_);
106
107 // 1. Query cache to analyse the information of cache hit and miss keys, update the positions of cache hit elements in
108 // the cache (cache refresh).
109 size_t cache_miss_cnt = 0;
110 size_t *cache_miss_offsets = this->template AllocateMemory<size_t>(sizeof(size_t) * key_num);
111 MS_EXCEPTION_IF_NULL(cache_miss_offsets);
112 bool *cache_hit = this->template AllocateMemory<bool>(sizeof(bool) * key_num);
113 MS_EXCEPTION_IF_NULL(cache_hit);
114 QueryCache(keys_data, key_num, cache_miss_offsets, &cache_miss_cnt, cache_hit);
115
116 // 2. Update the embedding value to the cache for cache hit keys.
117 for (size_t i = 0; i < key_num; i++) {
118 if (!cache_hit[i]) {
119 continue;
120 }
121 RETURN_IF_FALSE_WITH_LOG(hash_table_->Insert(keys_data + i, 1, values_data + this->embedding_dim_ * i, nullptr),
122 "Insert hash table failed.");
123 }
124
125 if (cache_miss_cnt == 0) {
126 return true;
127 }
128
129 // 3. Reserve space for cache miss keys in the cache (if there is enough space in the cache, then do nothing), write
130 // the evicted element to persistent storage, and record the space in the cache, using the space in the cache first.
131 RETURN_IF_FALSE_WITH_LOG(TryEvict(cache_miss_cnt), "Reserve space for miss keys failed.");
132
133 // 4. Insert the cache miss elements into the cache from host memory.
134 // Note: step 2 and step 4 can not merge.
135 RETURN_IF_FALSE_WITH_LOG(InsertMissCacheFromMemory(keys_data, cache_miss_offsets, cache_miss_cnt, values_data),
136 "Insert cache miss elements into cache from host memory failed.");
137
138 this->FreeMemory(cache_hit);
139 this->FreeMemory(cache_miss_offsets);
140 return true;
141 }
142
143 template <typename KeyType, typename ValueType, typename Allocator>
QueryCache(const KeyType * keys,size_t key_num,size_t * cache_miss_offsets,size_t * cache_miss_cnt,bool * cache_hit) const144 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::QueryCache(const KeyType *keys, size_t key_num,
145 size_t *cache_miss_offsets,
146 size_t *cache_miss_cnt, bool *cache_hit) const {
147 MS_EXCEPTION_IF_NULL(keys);
148 MS_EXCEPTION_IF_NULL(cache_miss_offsets);
149 MS_EXCEPTION_IF_NULL(cache_miss_cnt);
150 MS_EXCEPTION_IF_NULL(cache_hit);
151 MS_EXCEPTION_IF_NULL(this->cache_);
152
153 int fake_index;
154 for (size_t i = 0; i < key_num; i++) {
155 if (this->cache_->Get(keys[i], &fake_index)) {
156 // Touch keys to affect the location or order of the elements in the cache, the return value for hash table is
157 // useless.
158 cache_hit[i] = true;
159 continue;
160 }
161
162 cache_hit[i] = false;
163 // Record cache miss key's offset in all query keys.
164 cache_miss_offsets[(*cache_miss_cnt)++] = i;
165 }
166
167 MS_LOG(DEBUG) << "Total keys number: " << key_num << ", cache hit number: " << (key_num - *cache_miss_cnt)
168 << ", cache hit rate: " << static_cast<float>(key_num - *cache_miss_cnt) / static_cast<float>(key_num);
169 }
170
171 template <typename KeyType, typename ValueType, typename Allocator>
TryEvict(size_t reserve_size)172 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::TryEvict(size_t reserve_size) {
173 // 1. Try evict some non-hot data in cache to reserve space for elements that will be inserted into the cache.
174 if (reserve_size > this->cache_capacity_) {
175 MS_LOG(EXCEPTION) << "The evict number must be less or equal to cache capacity: " << this->cache_capacity_
176 << ", but got: " << reserve_size << ", please enlarge cache capacity";
177 }
178
179 MS_EXCEPTION_IF_NULL(this->cache_);
180 std::vector<CacheElement> evicted_elements;
181 this->cache_->TryEvict(reserve_size, &evicted_elements);
182 if (evicted_elements.size() == 0) {
183 return true;
184 }
185
186 size_t evicted_keys_len = evicted_elements.size() * sizeof(KeyType);
187 KeyType *evicted_keys = this->template AllocateMemory<KeyType>(evicted_keys_len);
188 MS_EXCEPTION_IF_NULL(evicted_keys);
189
190 size_t evicted_cnt = 0;
191 (void)std::for_each(evicted_elements.begin(), evicted_elements.end(),
192 [&, this](const CacheElement &element) { evicted_keys[evicted_cnt++] = element.first; });
193
194 // 2. Get all evicted embedding vector values.
195 size_t evicted_values_len = evicted_elements.size() * this->embedding_dim_ * sizeof(ValueType);
196 ValueType *evicted_values = this->template AllocateMemory<ValueType>(evicted_values_len);
197 MS_EXCEPTION_IF_NULL(evicted_values);
198 MS_EXCEPTION_IF_NULL(hash_table_);
199 for (size_t i = 0; i < evicted_elements.size(); i++) {
200 RETURN_IF_FALSE_WITH_LOG(
201 hash_table_->Find(evicted_keys + i, 1, false, evicted_values + this->embedding_dim_ * i, nullptr),
202 "Find key from hash table failed.");
203 // Erase evicted element from hash table after using.
204 RETURN_IF_FALSE_WITH_LOG(hash_table_->Erase(evicted_keys + i, 1, nullptr), "Erase key from hash table failed.");
205 }
206
207 if (this->cache_->size() != hash_table_->size()) {
208 MS_LOG(EXCEPTION) << "The size of cache and hash table should be equal, but got cache size[" << this->cache_->size()
209 << "], hash table size[" << hash_table_->size() << "].";
210 }
211
212 // 3. Write evicted elements to persistent storage.
213 MS_EXCEPTION_IF_NULL(this->storage_);
214 this->storage_->Write({evicted_keys, evicted_keys_len}, {evicted_values, evicted_values_len});
215
216 this->FreeMemory(evicted_keys);
217 this->FreeMemory(evicted_values);
218
219 return true;
220 }
221
222 template <typename KeyType, typename ValueType, typename Allocator>
InsertMissCacheFromStorage(const KeyType * keys,const size_t * cache_miss_offsets,size_t cache_miss_cnt,ValueType * values)223 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::InsertMissCacheFromStorage(const KeyType *keys,
224 const size_t *cache_miss_offsets,
225 size_t cache_miss_cnt,
226 ValueType *values) {
227 MS_EXCEPTION_IF_NULL(keys);
228 MS_EXCEPTION_IF_NULL(cache_miss_offsets);
229 MS_EXCEPTION_IF_NULL(values);
230 MS_EXCEPTION_IF_NULL(this->cache_);
231 MS_EXCEPTION_IF_NULL(hash_table_);
232
233 // 1. Read the cache miss element from the persistent storage.
234 size_t cache_miss_keys_len = cache_miss_cnt * sizeof(KeyType);
235 KeyType *cache_miss_keys = this->template AllocateMemory<KeyType>(cache_miss_keys_len);
236 MS_EXCEPTION_IF_NULL(cache_miss_keys);
237 for (size_t i = 0; i < cache_miss_cnt; i++) {
238 cache_miss_keys[i] = keys[cache_miss_offsets[i]];
239 }
240 size_t cache_miss_values_len = cache_miss_cnt * this->embedding_dim_ * sizeof(ValueType);
241 ValueType *cache_miss_values = this->template AllocateMemory<ValueType>(cache_miss_values_len);
242 MS_EXCEPTION_IF_NULL(cache_miss_values);
243
244 // Read the miss values from persistent storage.
245 MS_EXCEPTION_IF_NULL(this->storage_);
246 this->storage_->Read({cache_miss_keys, cache_miss_keys_len}, {cache_miss_values, cache_miss_values_len});
247
248 // 2. Insert the cache miss elements into cache, and copy them to the returned values.
249 for (size_t i = 0; i < cache_miss_cnt; i++) {
250 // Insert key-index pairs of the cache miss elements into the cache, the index for hash embedding table is useless,
251 // set the value to 0.
252 this->cache_->Put(cache_miss_keys[i], 0);
253
254 // Insert the embedding vectors of cache miss elements to the cache.
255 RETURN_IF_FALSE_WITH_LOG(
256 hash_table_->Insert(cache_miss_keys + i, 1, cache_miss_values + this->embedding_dim_ * i, nullptr),
257 "Insert hash table failed.");
258
259 // Copy the embedding vectors of cache miss elements to the returned values.
260 auto ret = memcpy_s(values + this->embedding_dim_ * cache_miss_offsets[i], this->embedding_dim_ * sizeof(ValueType),
261 cache_miss_values + this->embedding_dim_ * i, this->embedding_dim_ * sizeof(ValueType));
262 if (ret != EOK) {
263 MS_LOG(ERROR) << "Memcpy the embedding vectors of cache miss elements to the returned values failed, errno["
264 << ret << "]";
265 return false;
266 }
267 }
268
269 this->FreeMemory(cache_miss_keys);
270 this->FreeMemory(cache_miss_values);
271 return true;
272 }
273
274 template <typename KeyType, typename ValueType, typename Allocator>
InsertMissCacheFromMemory(const KeyType * keys,const size_t * cache_miss_offsets,size_t cache_miss_cnt,const ValueType * values)275 bool SparseEmbeddingStorage<KeyType, ValueType, Allocator>::InsertMissCacheFromMemory(const KeyType *keys,
276 const size_t *cache_miss_offsets,
277 size_t cache_miss_cnt,
278 const ValueType *values) {
279 MS_EXCEPTION_IF_NULL(keys);
280 MS_EXCEPTION_IF_NULL(cache_miss_offsets);
281 MS_EXCEPTION_IF_NULL(values);
282 MS_EXCEPTION_IF_NULL(this->cache_);
283 MS_EXCEPTION_IF_NULL(hash_table_);
284
285 for (size_t i = 0; i < cache_miss_cnt; i++) {
286 // Insert key-index pairs of the cache miss elements into the cache, the index for hash embedding table is useless,
287 // set the value to 0.
288 this->cache_->Put(keys[cache_miss_offsets[i]], 0);
289
290 // Insert the embedding vectors of cache miss elements to the cache.
291 RETURN_IF_FALSE_WITH_LOG(hash_table_->Insert(keys + cache_miss_offsets[i], 1,
292 values + this->embedding_dim_ * cache_miss_offsets[i], nullptr),
293 "Insert hash table failed.");
294 }
295
296 return true;
297 }
298
299 template <typename KeyType, typename ValueType, typename Allocator>
ExportSlice(bool,bool * last_slice,size_t slice_size_in_mega_bytes)300 std::vector<std::shared_ptr<std::vector<char>>> SparseEmbeddingStorage<KeyType, ValueType, Allocator>::ExportSlice(
301 bool, bool *last_slice, size_t slice_size_in_mega_bytes) {
302 MS_EXCEPTION_IF_NULL(last_slice);
303 // Only support fully export currently.
304 // 1. Export data in host cache.
305 if (!this->finish_export_element_in_host_mem_) {
306 return hash_table_->ExportSlice(false, &this->finish_export_element_in_host_mem_, slice_size_in_mega_bytes);
307 }
308
309 // 2. Export data in storage.
310 static KeyType *deduplicated_keys_in_storage = nullptr;
311 static size_t deduplicated_keys_num;
312 if (this->keys_in_storage_ == nullptr) {
313 this->keys_in_storage_ = this->storage_->GetAllKeys();
314 MS_EXCEPTION_IF_NULL(this->keys_in_storage_);
315 deduplicated_keys_num = LongToSize(std::count_if(this->keys_in_storage_->begin(), this->keys_in_storage_->end(),
316 [this](KeyType key) { return !this->cache_->Exists(key); }));
317 if (deduplicated_keys_num == 0) {
318 *last_slice = true;
319 this->keys_in_storage_ = nullptr;
320 return {std::make_shared<std::vector<char>>(0), std::make_shared<std::vector<char>>(0),
321 std::make_shared<std::vector<char>>(0)};
322 }
323
324 deduplicated_keys_in_storage = this->template AllocateMemory<KeyType>(deduplicated_keys_num * sizeof(KeyType));
325 MS_EXCEPTION_IF_NULL(deduplicated_keys_in_storage);
326 size_t index = 0;
327 for (size_t i = 0; i < this->keys_in_storage_->size(); ++i) {
328 if (!this->cache_->Exists(this->keys_in_storage_->at(i))) {
329 deduplicated_keys_in_storage[index] = this->keys_in_storage_->at(i);
330 ++index;
331 }
332 }
333 }
334
335 size_t slice_size = slice_size_in_mega_bytes * kMegaByteToByteRate / (this->embedding_dim_ * sizeof(ValueType));
336 if (slice_size == 0) {
337 MS_LOG(EXCEPTION) << "The parameter[slice_size_in_mega_bytes] " << slice_size_in_mega_bytes
338 << " should be greater than the length in meta bytes of one element in storage: "
339 << (this->embedding_dim_ * sizeof(ValueType)) / kMegaByteToByteRate;
340 }
341 if (this->end_ == 0) {
342 this->end_ = std::min(this->begin_ + slice_size, deduplicated_keys_num);
343 }
344
345 auto ret = ReadSliceFromStorage(deduplicated_keys_in_storage);
346 *last_slice = (this->end_ == deduplicated_keys_num);
347 // Update the iterator and record status.
348 UpdateExportStatus(*last_slice, slice_size, deduplicated_keys_num);
349 if (*last_slice) {
350 this->FreeMemory(deduplicated_keys_in_storage);
351 }
352
353 return ret;
354 }
355
356 template <typename KeyType, typename ValueType, typename Allocator>
357 std::vector<std::shared_ptr<std::vector<char>>>
ReadSliceFromStorage(KeyType * keys_in_storage) const358 SparseEmbeddingStorage<KeyType, ValueType, Allocator>::ReadSliceFromStorage(KeyType *keys_in_storage) const {
359 MS_EXCEPTION_IF_NULL(keys_in_storage);
360
361 if (this->end_ < this->begin_) {
362 MS_LOG(EXCEPTION) << "Invalid export position parameter, begin: " << this->begin_ << ", end: " << this->end_;
363 }
364 const size_t size = this->end_ - this->begin_;
365 auto keys = std::make_shared<std::vector<char>>(size * sizeof(KeyType));
366 auto keys_data = reinterpret_cast<KeyType *>(keys->data());
367 auto values = std::make_shared<std::vector<char>>(size * this->embedding_dim_ * sizeof(ValueType));
368 auto values_data = reinterpret_cast<ValueType *>(values->data());
369 auto statuses = std::make_shared<std::vector<char>>(size * sizeof(HashTableElementStatus));
370
371 auto ret = memcpy_s(keys_data, size * sizeof(KeyType), keys_in_storage + this->begin_, size * sizeof(KeyType));
372 if (ret != EOK) {
373 MS_LOG(EXCEPTION) << "Memcpy failed, errno[" << ret << "].";
374 }
375
376 // Read from storage
377 this->storage_->Read({keys_data, size * sizeof(KeyType)},
378 {values_data, size * this->embedding_dim_ * sizeof(ValueType)});
379
380 return {keys, values, statuses};
381 }
382
383 template <typename KeyType, typename ValueType, typename Allocator>
UpdateExportStatus(bool last_slice,size_t slice_size,size_t deduplicated_keys_num_in_storage)384 void SparseEmbeddingStorage<KeyType, ValueType, Allocator>::UpdateExportStatus(
385 bool last_slice, size_t slice_size, size_t deduplicated_keys_num_in_storage) {
386 if (last_slice) {
387 this->begin_ = 0;
388 this->end_ = 0;
389 this->keys_in_storage_ = nullptr;
390 this->finish_export_element_in_host_mem_ = false;
391 } else {
392 this->begin_ += slice_size;
393 this->end_ = std::min(this->begin_ + slice_size, deduplicated_keys_num_in_storage);
394 }
395 }
396
// Explicit instantiations. The member functions above are defined in this source file, so every combination of
// template arguments that is used elsewhere has to be instantiated here.

// 32-bit keys with each supported value type.
template class SparseEmbeddingStorage<int32_t, bool>;
template class SparseEmbeddingStorage<int32_t, int8_t>;
template class SparseEmbeddingStorage<int32_t, int16_t>;
template class SparseEmbeddingStorage<int32_t, int32_t>;
template class SparseEmbeddingStorage<int32_t, int64_t>;
template class SparseEmbeddingStorage<int32_t, uint8_t>;
template class SparseEmbeddingStorage<int32_t, uint16_t>;
template class SparseEmbeddingStorage<int32_t, uint32_t>;
template class SparseEmbeddingStorage<int32_t, uint64_t>;
template class SparseEmbeddingStorage<int32_t, float16>;
template class SparseEmbeddingStorage<int32_t, float>;
template class SparseEmbeddingStorage<int32_t, double>;
template class SparseEmbeddingStorage<int32_t, bfloat16>;

// 64-bit keys with each supported value type.
template class SparseEmbeddingStorage<int64_t, bool>;
template class SparseEmbeddingStorage<int64_t, int8_t>;
template class SparseEmbeddingStorage<int64_t, int16_t>;
template class SparseEmbeddingStorage<int64_t, int32_t>;
template class SparseEmbeddingStorage<int64_t, int64_t>;
template class SparseEmbeddingStorage<int64_t, uint8_t>;
template class SparseEmbeddingStorage<int64_t, uint16_t>;
template class SparseEmbeddingStorage<int64_t, uint32_t>;
template class SparseEmbeddingStorage<int64_t, uint64_t>;
template class SparseEmbeddingStorage<int64_t, float16>;
template class SparseEmbeddingStorage<int64_t, float>;
template class SparseEmbeddingStorage<int64_t, double>;
template class SparseEmbeddingStorage<int64_t, bfloat16>;

// float values with an explicitly specified std::allocator<uint8_t>.
// NOTE(review): presumably this differs from the default Allocator declared in the header - confirm there.
template class SparseEmbeddingStorage<int32_t, float, std::allocator<uint8_t>>;
template class SparseEmbeddingStorage<int64_t, float, std::allocator<uint8_t>>;
427 } // namespace storage
428 } // namespace distributed
429 } // namespace mindspore
430