1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 19 20 #include <vector> 21 #include "distributed/embedding_cache/embedding_storage/embedding_storage.h" 22 23 namespace mindspore { 24 namespace distributed { 25 namespace storage { 26 /** 27 * @brief A derived class for dense implementation to manage lookup and update of a huge Embedding Table for Tensor 28 * type. 29 */ 30 template <typename KeyType, typename ValueType, typename Allocator = Allocator<uint8_t>> 31 class DenseEmbeddingStorage : public EmbeddingStorage<KeyType, ValueType, Allocator> { 32 public: 33 // The cache element type, a key-value pair, key is same the key of this dense embedding storage, value is the 34 // index(line number in embedding table tensor) of the key. 35 using CacheElement = typename EmbeddingStorage<KeyType, ValueType, Allocator>::CacheType::Element; 36 37 DenseEmbeddingStorage(int32_t embedding_key, size_t embedding_dim, size_t cache_capacity, 38 const Allocator &alloc = Allocator()) 39 : EmbeddingStorage<KeyType, ValueType, Allocator>(embedding_key, embedding_dim, cache_capacity, alloc) {} 40 ~DenseEmbeddingStorage() override = default; 41 42 /** 43 * @brief Initialize the EmbeddingStorage, such as recording the device address of the Embedding Table corresponding 44 * to the DenseEmbeddingStorage. 45 * @param[in] `device_address`: The device address of the Embedding Table tensor parameter 46 * corresponding to the DenseEmbeddingStorage. 47 */ 48 void Initialize(const DeviceAddress *device_address) override; 49 50 /** 51 * @brief Finalize the EmbeddingStorage, release allocated resource. 52 */ 53 void Finalize() override; 54 55 /** 56 * @brief Batch embeddings lookup operation. 57 * Query Embeddings in the host cache first, if the corresponding element cannot be found in the host cache, then read 58 * the element from the persistent storage and insert host cache. 59 * Access an element of the cache generally affects the location or order of the elements in the cache, depending 60 * on different cache strategies. 61 */ 62 bool Get(const ConstDataWithLen &keys, const DataWithLen &values) override; 63 64 /** 65 * @brief Batch embeddings update/insert operation. 66 * Update/Insert Embeddings in the host cache first, if the host cache has insufficient space, the expired elements 67 * will automatically be evicted the to the persistent storage. 68 * Update or Insert an element of the cache generally affects the location or order of the elements in the cache, 69 * depending on different cache strategies. 70 */ 71 bool Put(const ConstDataWithLen &keys, const ConstDataWithLen &values) override; 72 73 private: 74 /** 75 * @brief Query cache to get the index in the embedding table tensor of each cache hit key, and count the number of 76 * cache miss keys. Access an element of the cache generally affects the location or order of the elements in the 77 * cache, depending on different cache strategies. 78 * @param[in] `keys`: The array records all keys which need to query. 79 * @param[in] `key_num`: The number of keys which need to query. 80 * @param[out] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 81 * @param[out] `cache_miss_cnt`: The number of cache miss keys. 82 * @param[out] `indices_in_cache`: The array records the indices in the embedding table tensor of each cache hit 83 * keys. 84 */ 85 void QueryCache(const KeyType *keys, size_t key_num, size_t *cache_miss_offsets, size_t *cache_miss_cnt, 86 int *indices_in_cache) const; 87 88 /** 89 * @brief Reserve space for cache miss keys in the cache, write the evicted element to persistent storage, 90 * and record the new space position in the cache. 91 * @param[in] `reserve_size`: The number of element slots that are expected to be reserved. If the 92 * reserve_size is less than or equal to the number of slots remaining in the cache, the function does nothing. 93 * @return Whether the function was successfully executed. 94 */ 95 bool TryEvict(size_t reserve_size); 96 97 /** 98 * @brief Insert the cache miss elements into the cache from persistent storage, and copy them to the output values. 99 * @param[in] `keys`: The array records all origin keys for batch embeddings lookup operation. 100 * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 101 * @param[in] `cache_miss_cnt`: The number of cache miss keys. 102 * @param[out] `values`: The output embeddings. 103 * @return Whether the function was successfully executed. 104 */ 105 bool InsertMissCacheFromStorage(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt, 106 ValueType *values); 107 108 /** 109 * @brief Insert the cache miss elements into the cache from host memory. 110 * @param[in] `keys`: The array records all origin keys for batch embeddings update/insert operation. 111 * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 112 * @param[in] `cache_miss_cnt`: The number of cache miss keys. 113 * @param[in] `values`: Embeddings corresponding to all keys need to be updated. 114 * @return Whether the function was successfully executed. 115 */ 116 bool InsertMissCacheFromMemory(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt, 117 const ValueType *values); 118 119 // The base pointer to embedding table parameter, all embeddings in host cache is recorded in 120 // embedding_param_address_. 121 const DeviceAddress *embedding_param_address_{nullptr}; 122 123 // For performance, keep the pointer snapshot for `embedding_param_address_`. 124 ValueType *embedding_param_ptr_{nullptr}; 125 126 // Record all empty slot(idle slot or index) in embedding table tensor. 127 std::vector<int> empty_slots_; 128 }; 129 } // namespace storage 130 } // namespace distributed 131 } // namespace mindspore 132 #endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 133