• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
19 
20 #include <vector>
21 #include "distributed/embedding_cache/embedding_storage/embedding_storage.h"
22 
23 namespace mindspore {
24 namespace distributed {
25 namespace storage {
26 /**
27  * @brief A derived class of EmbeddingStorage providing the dense implementation, which manages lookup and update of a
28  * huge Embedding Table for Tensor type.
29  */
30 template <typename KeyType, typename ValueType, typename Allocator = Allocator<uint8_t>>
31 class DenseEmbeddingStorage : public EmbeddingStorage<KeyType, ValueType, Allocator> {
32  public:
33   // The cache element type: a key-value pair whose key is the same as the key of this dense embedding storage and
34   // whose value is the index (line number in the embedding table tensor) of that key.
35   using CacheElement = typename EmbeddingStorage<KeyType, ValueType, Allocator>::CacheType::Element;
36   // The constructor only forwards its arguments, unchanged, to the EmbeddingStorage base class constructor.
37   DenseEmbeddingStorage(int32_t embedding_key, size_t embedding_dim, size_t cache_capacity,
38                         const Allocator &alloc = Allocator())
39       : EmbeddingStorage<KeyType, ValueType, Allocator>(embedding_key, embedding_dim, cache_capacity, alloc) {}
40   ~DenseEmbeddingStorage() override = default;
41 
42   /**
43    * @brief Initialize the EmbeddingStorage, for example by recording the device address of the Embedding Table
44    * corresponding to this DenseEmbeddingStorage.
45    * @param[in] `device_address`: The device address of the Embedding Table tensor parameter
46    * corresponding to the DenseEmbeddingStorage.
47    */
48   void Initialize(const DeviceAddress *device_address) override;
49 
50   /**
51    * @brief Finalize the EmbeddingStorage and release the allocated resources.
52    */
53   void Finalize() override;
54 
55   /**
56    * @brief Batch embeddings lookup operation.
57    * Query Embeddings in the host cache first; if an element cannot be found in the host cache, read it from the
58    * persistent storage and insert it into the host cache.
59    * Accessing an element of the cache generally affects the location or order of the elements in the cache,
60    * depending on the cache strategy in use.
61    */
62   bool Get(const ConstDataWithLen &keys, const DataWithLen &values) override;
63 
64   /**
65    * @brief Batch embeddings update/insert operation.
66    * Update/Insert Embeddings in the host cache first; if the host cache has insufficient space, the expired elements
67    * are automatically evicted to the persistent storage.
68    * Updating or inserting an element of the cache generally affects the location or order of the elements in the
69    * cache, depending on the cache strategy in use.
70    */
71   bool Put(const ConstDataWithLen &keys, const ConstDataWithLen &values) override;
72 
73  private:
74   /**
75    * @brief Query the cache to get the index in the embedding table tensor of each cache hit key, and count the number
76    * of cache miss keys. Accessing an element of the cache generally affects the location or order of the elements in
77    * the cache, depending on the cache strategy in use.
78    * @param[in] `keys`: The array that records all keys which need to be queried.
79    * @param[in] `key_num`: The number of keys which need to be queried.
80    * @param[out] `cache_miss_offsets`: The array records the offset(index) of each cache miss key in the original keys array.
81    * @param[out] `cache_miss_cnt`: The number of cache miss keys.
82    * @param[out] `indices_in_cache`: The array records the indices in the embedding table tensor of each cache hit
83    * key. NOTE(review): this method is const although a query may reorder cache elements - confirm this is intended.
84    */
85   void QueryCache(const KeyType *keys, size_t key_num, size_t *cache_miss_offsets, size_t *cache_miss_cnt,
86                   int *indices_in_cache) const;
87 
88   /**
89    * @brief Reserve space for cache miss keys in the cache, write the evicted elements to persistent storage,
90    * and record the new space positions in the cache.
91    * @param[in] `reserve_size`: The number of element slots that are expected to be reserved. If the
92    * reserve_size is less than or equal to the number of slots remaining in the cache, the function does nothing.
93    * @return Whether the function was successfully executed.
94    */
95   bool TryEvict(size_t reserve_size);
96 
97   /**
98    * @brief Insert the cache miss elements into the cache from persistent storage, and copy them to the output values.
99    * @param[in] `keys`: The array records all original keys for the batch embeddings lookup operation.
100    * @param[in] `cache_miss_offsets`: The array records the offset(index) of each cache miss key in the original keys array.
101    * @param[in] `cache_miss_cnt`: The number of cache miss keys.
102    * @param[out] `values`: The output embeddings.
103    * @return Whether the function was successfully executed.
104    */
105   bool InsertMissCacheFromStorage(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt,
106                                   ValueType *values);
107 
108   /**
109    * @brief Insert the cache miss elements into the cache from host memory.
110    * @param[in] `keys`: The array records all original keys for the batch embeddings update/insert operation.
111    * @param[in] `cache_miss_offsets`: The array records the offset(index) of each cache miss key in the original keys array.
112    * @param[in] `cache_miss_cnt`: The number of cache miss keys.
113    * @param[in] `values`: The embeddings corresponding to all keys that need to be updated.
114    * @return Whether the function was successfully executed.
115    */
116   bool InsertMissCacheFromMemory(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt,
117                                  const ValueType *values);
118 
119   // The base pointer to the embedding table parameter; all embeddings in the host cache are recorded in
120   // embedding_param_address_.
121   const DeviceAddress *embedding_param_address_{nullptr};
122 
123   // For performance, keep a pointer snapshot of `embedding_param_address_`.
124   ValueType *embedding_param_ptr_{nullptr};
125 
126   // Records all empty slots (idle slots, i.e. free indices) in the embedding table tensor.
127   std::vector<int> empty_slots_;
128 };
129 }  // namespace storage
130 }  // namespace distributed
131 }  // namespace mindspore
132 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_DENSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
133