1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 19 20 #include <string> 21 #include <memory> 22 #include <vector> 23 24 #include "distributed/embedding_cache/embedding_storage/embedding_storage.h" 25 #include "runtime/device/hash_table.h" 26 27 namespace mindspore { 28 namespace distributed { 29 namespace storage { 30 /** 31 * A derived class for sparse implementation to manage lookup and update of a huge Embedding Table for Hash Table type. 32 */ 33 template <typename KeyType, typename ValueType, typename Allocator = Allocator<uint8_t>> 34 class SparseEmbeddingStorage : public EmbeddingStorage<KeyType, ValueType, Allocator> { 35 public: 36 // The cache element type, a key-value pair, key is same the key of this sparse embedding storage, value is 37 // meaningless for now. 38 using CacheElement = typename EmbeddingStorage<KeyType, ValueType, Allocator>::CacheType::Element; 39 // The hash table type corresponding to the sparse embedding storage, and they have same key-value type. 40 using HashTable = device::HashTable<KeyType, ValueType>; 41 42 SparseEmbeddingStorage(int32_t embedding_key, size_t embedding_dim, size_t cache_capacity, 43 const Allocator &alloc = Allocator()) 44 : EmbeddingStorage<KeyType, ValueType, Allocator>(embedding_key, embedding_dim, cache_capacity, alloc) {} 45 ~SparseEmbeddingStorage() override = default; 46 47 /** 48 * @brief Initialize the EmbeddingStorage, such as recording the hash table of the Embedding Table corresponding to 49 * the SparseEmbeddingStorage. 50 * @param[in] `device_address`: The device address of the Embedding Table parameter corresponding to the 51 * SparseEmbeddingStorage. 52 */ 53 void Initialize(const DeviceAddress *device_address) override; 54 55 /** 56 * @brief Finalize the EmbeddingStorage, release allocated resource. 57 */ 58 void Finalize() override; 59 60 /** 61 * @brief Batch embeddings lookup operation. 62 * Query Embeddings in the host cache first, if the corresponding element cannot be found in the host cache, then read 63 * the element from the persistent storage and insert host cache. 64 * Access an element of the cache generally affects the location or order of the elements in the cache, depending 65 * on different cache strategies. 66 */ 67 bool Get(const ConstDataWithLen &keys, const DataWithLen &values) override; 68 69 /** 70 * @brief Batch embeddings update/insert operation. 71 * Update/Insert Embeddings in the host cache first, if the host cache has insufficient space, the expired elements 72 * will automatically be evicted the to the persistent storage. 73 * Update or Insert an element of the cache generally affects the location or order of the elements in the cache, 74 * depending on different cache strategies. 75 */ 76 bool Put(const ConstDataWithLen &keys, const ConstDataWithLen &values) override; 77 78 /** 79 * @brief To export a slice from the storage, the size is specified by the parameter 'slice_size_in_mega_bytes' in MB. 80 */ 81 std::vector<std::shared_ptr<std::vector<char>>> ExportSlice(bool incremental, bool *last_slice, 82 size_t slice_size_in_mega_bytes) override; 83 84 private: 85 /** 86 * @brief Query cache to analyse the information of cache hit and miss keys. Access an element of the cache generally 87 * affects the location or order of the elements in the cache, depending on different cache strategies. 88 * @param[in] `keys`: The array records all keys which need to query. 89 * @param[in] `key_num`: The number of keys which need to query. 90 * @param[out] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 91 * @param[out] `cache_miss_cnt`: The number of cache miss keys. 92 * @param[out] `cache_hit`: The array records cache hit/miss for each key. 93 * keys. 94 */ 95 void QueryCache(const KeyType *keys, size_t key_num, size_t *cache_miss_offsets, size_t *cache_miss_cnt, 96 bool *cache_hit) const; 97 98 /** 99 * @brief Reserve space for cache miss keys in the cache, write the evicted element to persistent storage, 100 * and record the new space position in the cache. 101 * @param[in] `reserve_size`: The number of element slots that are expected to be reserved. If the 102 * reserve_size is less than or equal to the number of slots remaining in the cache, the function does nothing. 103 * @return Whether the function was successfully executed. 104 */ 105 bool TryEvict(size_t reserve_size); 106 107 /** 108 * @brief Insert the cache miss elements into the cache from persistent storage, and copy them to the output values. 109 * @param[in] `keys`: The array records all origin keys for batch embeddings lookup operation. 110 * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 111 * @param[in] `cache_miss_cnt`: The number of cache miss keys. 112 * @param[out] `values`: The output embeddings. 113 * @return Whether the function was successfully executed. 114 */ 115 bool InsertMissCacheFromStorage(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt, 116 ValueType *values); 117 118 /** 119 * @brief Insert the cache miss elements into the cache from host memory. 120 * @param[in] `keys`: The array records all origin keys for batch embeddings update/insert operation. 121 * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array. 122 * @param[in] `cache_miss_cnt`: The number of cache miss keys. 123 * @param[in] `values`: Embeddings corresponding to all keys need to be updated. 124 * @return Whether the function was successfully executed. 125 */ 126 bool InsertMissCacheFromMemory(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt, 127 const ValueType *values); 128 129 /** 130 * @brief Read slice data from storage. 131 * @param[in] `keys_in_storage`: The array records all keys which only exist in storage. 132 * @return The byte sequence of data read form storage. 133 */ 134 std::vector<std::shared_ptr<std::vector<char>>> ReadSliceFromStorage(KeyType *keys_in_storage) const; 135 136 /** 137 * @brief Update begin and end iterator and some record status. 138 * @param[in] `last_slice`: Indicate whether the slice by export is the last slice, that is, 139 * the export is complete. 140 * @param[in] `slice_size`: The number of elements in a slice to export. 141 * @param[in] `deduplicated_keys_num_in_storage`: The total keys number which only exist in storage. 142 */ 143 void UpdateExportStatus(bool last_slice, size_t slice_size, size_t deduplicated_keys_num_in_storage); 144 145 // The base pointer to the hash table of the embedding table parameter. 146 // All embeddings in host cache is recorded in it. 147 HashTable *hash_table_{nullptr}; 148 }; 149 } // namespace storage 150 } // namespace distributed 151 } // namespace mindspore 152 #endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_ 153