• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
19 
20 #include <string>
21 #include <memory>
22 #include <vector>
23 
24 #include "distributed/embedding_cache/embedding_storage/embedding_storage.h"
25 #include "runtime/device/hash_table.h"
26 
27 namespace mindspore {
28 namespace distributed {
29 namespace storage {
30 /**
31  * A derived class of EmbeddingStorage providing the sparse implementation: it manages lookup and update of a huge
32  * Embedding Table of Hash Table type. */
33 template <typename KeyType, typename ValueType, typename Allocator = Allocator<uint8_t>>
34 class SparseEmbeddingStorage : public EmbeddingStorage<KeyType, ValueType, Allocator> {
35  public:
36   // The cache element type, a key-value pair: the key is the same as the key of this sparse embedding storage, the
37   // value is meaningless for now.
38   using CacheElement = typename EmbeddingStorage<KeyType, ValueType, Allocator>::CacheType::Element;
39   // The hash table type corresponding to the sparse embedding storage; both have the same key and value types.
40   using HashTable = device::HashTable<KeyType, ValueType>;
41   // Constructs the storage by forwarding every argument unchanged to the EmbeddingStorage base-class constructor.
42   SparseEmbeddingStorage(int32_t embedding_key, size_t embedding_dim, size_t cache_capacity,
43                          const Allocator &alloc = Allocator())
44       : EmbeddingStorage<KeyType, ValueType, Allocator>(embedding_key, embedding_dim, cache_capacity, alloc) {}
45   ~SparseEmbeddingStorage() override = default;
46 
47   /**
48    * @brief Initialize the EmbeddingStorage, such as recording the hash table of the Embedding Table corresponding to
49    * the SparseEmbeddingStorage.
50    * @param[in] `device_address`: The device address of the Embedding Table parameter corresponding to the
51    * SparseEmbeddingStorage.
52    */
53   void Initialize(const DeviceAddress *device_address) override;
54 
55   /**
56    * @brief Finalize the EmbeddingStorage and release the allocated resources.
57    */
58   void Finalize() override;
59 
60   /**
61    * @brief Batch embeddings lookup operation.
62    * Query embeddings in the host cache first; if an element cannot be found in the host cache, it is read from the
63    * persistent storage and inserted into the host cache. Accessing an element of the cache generally affects the
64    * location or order of the elements in the cache, depending on the cache strategy.
65    * NOTE(review): the bool result presumably reports success, as for the private helpers -- confirm in definition.
66    */
67   bool Get(const ConstDataWithLen &keys, const DataWithLen &values) override;
68 
69   /**
70    * @brief Batch embeddings update/insert operation.
71    * Update/insert embeddings in the host cache first; if the host cache has insufficient space, the expired elements
72    * are automatically evicted to the persistent storage. Updating or inserting an element of the cache generally
73    * affects the location or order of the elements in the cache, depending on the cache strategy.
74    * NOTE(review): the bool result presumably reports success, as for the private helpers -- confirm in definition.
75    */
76   bool Put(const ConstDataWithLen &keys, const ConstDataWithLen &values) override;
77 
78   /**
79    * @brief Export a slice from the storage; the slice size is given by 'slice_size_in_mega_bytes', in MB.
80    */
81   std::vector<std::shared_ptr<std::vector<char>>> ExportSlice(bool incremental, bool *last_slice,
82                                                               size_t slice_size_in_mega_bytes) override;
83 
84  private:
85   /**
86    * @brief Query the cache to analyse which keys hit and which keys miss. Accessing an element of the cache generally
87    * affects the location or order of the elements in the cache, depending on the cache strategy.
88    * @param[in] `keys`: The array records all keys which need to be queried.
89    * @param[in] `key_num`: The number of keys which need to be queried.
90    * @param[out] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array.
91    * @param[out] `cache_miss_cnt`: The number of cache miss keys.
92    * @param[out] `cache_hit`: The array records cache hit/miss for each of the queried keys.
93    * NOTE(review): output arrays are presumably sized by the caller for `key_num` entries -- confirm in definition.
94    */
95   void QueryCache(const KeyType *keys, size_t key_num, size_t *cache_miss_offsets, size_t *cache_miss_cnt,
96                   bool *cache_hit) const;
97 
98   /**
99    * @brief Reserve space for cache miss keys in the cache, write the evicted elements to persistent storage,
100    * and record the new space position in the cache.
101    * @param[in] `reserve_size`: The number of element slots that are expected to be reserved. If the
102    * reserve_size is less than or equal to the number of slots remaining in the cache, the function does nothing.
103    * @return Whether the function was successfully executed.
104    */
105   bool TryEvict(size_t reserve_size);
106 
107   /**
108    * @brief Insert the cache miss elements into the cache from persistent storage, and copy them to the output values.
109    * @param[in] `keys`: The array records all origin keys for batch embeddings lookup operation.
110    * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array.
111    * @param[in] `cache_miss_cnt`: The number of cache miss keys.
112    * @param[out] `values`: The output embeddings.
113    * @return Whether the function was successfully executed.
114    */
115   bool InsertMissCacheFromStorage(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt,
116                                   ValueType *values);
117 
118   /**
119    * @brief Insert the cache miss elements into the cache from host memory.
120    * @param[in] `keys`: The array records all origin keys for batch embeddings update/insert operation.
121    * @param[in] `cache_miss_offsets`: The array records the offset(index) of cache miss key in origin keys array.
122    * @param[in] `cache_miss_cnt`: The number of cache miss keys.
123    * @param[in] `values`: Embeddings corresponding to all keys that need to be updated.
124    * @return Whether the function was successfully executed.
125    */
126   bool InsertMissCacheFromMemory(const KeyType *keys, const size_t *cache_miss_offsets, size_t cache_miss_cnt,
127                                  const ValueType *values);
128 
129   /**
130    * @brief Read slice data from storage.
131    * @param[in] `keys_in_storage`: The array records all keys which only exist in storage.
132    * @return The byte sequence of data read from storage.
133    */
134   std::vector<std::shared_ptr<std::vector<char>>> ReadSliceFromStorage(KeyType *keys_in_storage) const;
135 
136   /**
137    * @brief Update the begin and end iterators and some recorded status of the export.
138    * @param[in] `last_slice`: Indicates whether the exported slice is the last slice, that is,
139    * the export is complete.
140    * @param[in] `slice_size`: The number of elements in a slice to export.
141    * @param[in] `deduplicated_keys_num_in_storage`: The total number of keys which only exist in storage.
142    */
143   void UpdateExportStatus(bool last_slice, size_t slice_size, size_t deduplicated_keys_num_in_storage);
144 
145   // The base pointer to the hash table of the embedding table parameter.
146   // All embeddings in the host cache are recorded in it. Defaults to nullptr until it is set.
147   HashTable *hash_table_{nullptr};
148 };
149 }  // namespace storage
150 }  // namespace distributed
151 }  // namespace mindspore
152 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_SPARSE_EMBEDDING_STORAGE_EMBEDDING_STORAGE_H_
153