• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_
19 
#include <atomic>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "kernel/kernel.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"
#include "include/backend/distributed/embedding_cache/embedding_storage/abstract_embedding_storage.h"
#include "include/backend/distributed/embedding_cache/embedding_hash_map.h"
#include "include/backend/distributed/embedding_cache/blocking_queue.h"
#include "include/backend/data_queue/data_queue.h"
34 
35 namespace mindspore {
36 namespace runtime {
37 class EmbeddingCachePrefetchActor;
38 class DeviceEmbeddingOperation;
39 class DeviceDenseEmbeddingOperation;
40 class DeviceSparseEmbeddingOperation;
41 }  // namespace runtime
42 
43 namespace distributed {
// The local host cache size defaults to 10 times the device cache size.
static constexpr size_t kHostCacheScaleFactor = 10;
// The maximum number of concurrent threads for data prefetching.
static constexpr size_t kMaxThreadNum = 16;
// Maximum number of feature ids processed per thread.
static constexpr size_t kMaxIdsPerThread = 10000;

// Prefetch 16 batches of data at once.
static constexpr size_t kMultiBatchThreshold = 16;
53 
54 using mindspore::device::DeviceAddress;
55 using mindspore::kernel::Address;
56 
57 // The type of embedding tables.
58 enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 };
59 
60 // The initialization information for embedding tables.
61 struct ParamInitInfo {
62   std::string param_name_;
63   ParamType param_type_{kUnKnown};
64   size_t global_seed_{0};
65   size_t op_seed_{0};
66   float init_val_{0};
67 };
68 
69 // The hash tables records information such as the dimension, memory address, and cache size of the embedding table
70 // with the embedding cache enabled.
71 struct HashTableInfo {
72   size_t cache_vocab_size{0};
73   size_t host_cache_vocab_size{0};
74   size_t embedding_size{0};
75   size_t vocab_size{0};
76   // For performance, set address the snapshot of device_address.
77   Address address{nullptr, 0};
78   DeviceAddress *device_address{nullptr};
79   float *host_address{nullptr};
80   ParamInitInfo param_init_info_;
81   int32_t param_key_{-1};
82 };
83 
// Record the hash mapping relationship of all embedding tables with cache enabled on the device side, and the
// ids information that needs to be exchanged with the local host cache. Note that the following information of
// all embedding cache tables on the device side is same: hash mapping, and feature ids of feature vectors that need
// to be swapped with the local host cache.
struct EmbeddingDeviceCache {
  // Allocates the four index/id buffers, each sized for `batch_ids_num` ids (defined out of line).
  explicit EmbeddingDeviceCache(size_t batch_ids_num);

  // Cache slots/ids to evict from device to local host.
  std::unique_ptr<int[]> device_to_host_index;
  std::unique_ptr<int[]> device_to_host_ids;
  // Cache slots/ids to load from local host into device.
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> host_to_device_ids;
};
96 
// Record the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the
// information that needs to be exchanged with the remote cache and device cache. Note that the following information of
// all embedding cache tables on the local host side is same: hash mapping, and feature ids of feature vectors that need
// to be swapped with the remote cache and device cache.
struct EmbeddingHostCache {
  // Allocates the index/id buffers, each sized for `batch_ids_num` ids (defined out of line).
  explicit EmbeddingHostCache(size_t batch_ids_num);

  // Cache slots/ids to evict from local host to the remote server.
  std::unique_ptr<int[]> host_to_server_index;
  std::unique_ptr<int[]> host_to_server_ids;
  // Cache slots/ids to load from the remote server into local host.
  std::unique_ptr<int[]> server_to_host_index;
  std::unique_ptr<int[]> server_to_host_ids;
  // Slots of ids seen for the first time.
  std::unique_ptr<int[]> new_id_index;
  // Slots exchanged with the device cache.
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> device_to_host_index;
};
112 
// Counters describing one round of cache analysis: how many ids were seen, how many were unique, and how much
// data moved between device, local host, and remote server caches.
struct EmbeddingCacheStatisticsInfo {
  size_t batch_id_count_{0};
  size_t batch_id_unique_count_{0};
  size_t device_to_host_size_{0};
  size_t host_to_device_size_{0};
  size_t host_to_server_size_{0};
  size_t server_to_host_size_{0};
  size_t new_id_size_{0};
  size_t hash_hit_count_{0};
  size_t mem_cache_swap_out_size_{0};
  size_t mem_cache_swap_in_size_{0};
  size_t mem_cache_hit_count_{0};
};
126 
127 // Origin id data item recorder.
128 struct IdDataInfo {
129   IdDataInfo() = default;
IdDataInfoIdDataInfo130   IdDataInfo(void *data, size_t size, std::vector<device::DataQueueItem> *items, bool end_of_epoch, bool end_of_file)
131       : data_(data), size_(size), items_(items), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {}
132 
133   void *data_{nullptr};
134   size_t size_{0};
135   std::vector<device::DataQueueItem> *items_{nullptr};
136   bool end_of_epoch_{false};
137   bool end_of_file_{false};
138 };
139 
140 // The indexes data after cache prefetch.
141 struct IndexDataInfo {
142   IndexDataInfo() = default;
IndexDataInfoIndexDataInfo143   IndexDataInfo(void *data, std::vector<device::DataQueueItem> *items, bool end_of_epoch, bool end_of_file)
144       : data_(data), items_(items), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {}
145 
146   void *data_{nullptr};
147   std::vector<device::DataQueueItem> *items_{nullptr};
148   bool end_of_epoch_{false};
149   bool end_of_file_{false};
150 };
151 
152 // The origin unique data recorder.
153 struct UniqueIds {
154   UniqueIds() = default;
155 
156   size_t data_step_{0};
157   std::vector<void *> multi_batch_data_;
158   std::vector<size_t> multi_batch_size_;
159   std::vector<std::vector<device::DataQueueItem> *> multi_batch_items_;
160   int *ids_{nullptr};
161   size_t ids_num_{0};
162 
163   bool end_of_epoch_{false};
164   bool end_of_file_{false};
165 };
166 
167 // Record all information used to analyse cache.
168 struct CacheAnalysis {
169   CacheAnalysis() = default;
CacheAnalysisCacheAnalysis170   CacheAnalysis(EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache,
171                 EmbeddingCacheStatisticsInfo *statistics_info, UniqueIds *unique_ids, int *indices, bool end_of_epoch,
172                 bool end_of_file)
173       : embedding_device_cache_(embedding_device_cache),
174         embedding_host_cache_(embedding_host_cache),
175         statistics_info_(statistics_info),
176         unique_ids_(unique_ids),
177         indices_(indices),
178         end_of_epoch_(end_of_epoch),
179         end_of_file_(end_of_file) {}
180 
181   // Record the ids information that needs to be exchanged with the local host cache.
182   EmbeddingDeviceCache *embedding_device_cache_{nullptr};
183   // Record the information that needs to be exchanged with the remote cache and device cache.
184   EmbeddingHostCache *embedding_host_cache_{nullptr};
185   EmbeddingCacheStatisticsInfo *statistics_info_{nullptr};
186   UniqueIds *unique_ids_{nullptr};
187   int *indices_{nullptr};
188   bool end_of_epoch_{false};
189   bool end_of_file_{false};
190 };
191 
192 // Record all ids(after unique) and indices(after cache analysis)
193 struct IdsAndIndices {
194   IdsAndIndices() = default;
IdsAndIndicesIdsAndIndices195   IdsAndIndices(UniqueIds *unique_ids, int *indices, bool end_of_epoch, bool end_of_file)
196       : unique_ids_(unique_ids), indices_(indices), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {}
197 
198   UniqueIds *unique_ids_{nullptr};
199   int *indices_{nullptr};
200   bool end_of_epoch_{false};
201   bool end_of_file_{false};
202 };
203 
204 // The EmbeddingCacheTableManager class is used to save all Parameter information for enabling cache, such as device
205 // cache size, host cache size, etc., and can allocate memory for the embedding cache table.
206 class BACKEND_EXPORT EmbeddingCacheTableManager {
207  public:
208   using WarmUpCacheMapValue = std::tuple<tensor::TensorPtr, tensor::TensorPtr, tensor::TensorPtr>;
209   using WarmUpCacheMapEntry = std::pair<int32_t, WarmUpCacheMapValue>;
210   using WarmUpCacheMap = std::map<int32_t, WarmUpCacheMapValue>;
211   static EmbeddingCacheTableManager &GetInstance();
212 
213   // Initialize the EmbeddingCacheTableManager.
214   void Initialize();
215   // Finalize the EmbeddingCacheTableManager and release all resource.
216   void Finalize(const device::DeviceContext *device_context);
217 
218   // Insert and save dimension information of the embedding cache table.
219   void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
220                            size_t vocab_size, int32_t param_key);
221 
222   // Parameter will modify the name. After modification, you need to re-insert all the dimension information that saves
223   // the parameter.
224   void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name);
225 
226   // Insert the initial value for the accumulation value of embedding's optimizer.
227   void InsertAccumuInitInfo(const std::string &param_name, float init_val);
228 
229   // Clone a hash table, such as the optimizer's state parameters are generally cloned from weight.
230   void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name,
231                       int32_t src_param_key);
232 
233   // Set the device address for embedding cache table, using the same device address with parameter node.
234   void SetEmbeddingDeviceAddress(const std::string &param_name, DeviceAddress *device_address);
235 
236   // Alloc device memory for all embedding cache table.
237   void AllocMemForEmbedding(const device::DeviceContext *device_context);
238 
239   // Qeury device address of a embedding cache table.
240   const DeviceAddress *QueryEmbeddingDeviceAddress(const std::string &param_name) const;
241 
242   // Qeury device cache size of a embedding cache table.
243   size_t QueryHashTableSize(const std::string &param_name) const;
244 
245   // Check whether a parameter is cache enabled embedding table.
IsEmbeddingCacheTable(const std::string & param_name)246   bool IsEmbeddingCacheTable(const std::string &param_name) const { return hash_tables_.count(param_name) != 0; }
247 
248   // Set ids number of a batchsize.
set_batch_ids_num(size_t batch_ids_num)249   void set_batch_ids_num(size_t batch_ids_num) { batch_ids_num_ = batch_ids_num; }
250 
251   //  Get the offset of the id range corresponding to the embedding cache table slice on each worker in a multi-worker
252   //  automatic parallel scenario.
253   int cache_indices_lower_bound() const;
254 
255   // Set embedding vocab cache size on device.
set_cache_size(size_t cache_size)256   void set_cache_size(size_t cache_size) { device_cache_size_ = cache_size; }
257 
258   // Get embedding vocab cache size on device.
cache_size()259   size_t cache_size() const { return device_cache_size_; }
260 
261   // Set the storage format (`dense` or `sparse`) of embedding tables.
set_sparse_format(bool is_sparse)262   void set_sparse_format(bool is_sparse) { sparse_format_ = is_sparse; }
263 
is_sparse_format()264   bool is_sparse_format() { return sparse_format_; }
265 
266   // Get whether multi-stage pipeline cache prefetch is enabled.
267   bool enable_pipeline() const;
268 
269   void DumpHashTables() const;
270 
checkpoint_load_status()271   bool checkpoint_load_status() const { return checkpoint_load_status_; }
272 
set_checkpoint_load_status(bool checkpoint_load_status)273   void set_checkpoint_load_status(bool checkpoint_load_status) { checkpoint_load_status_ = checkpoint_load_status; }
274 
275   int32_t StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &tensor_ptr);
276 
277   int32_t StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &key_ptr, const tensor::TensorPtr &value_ptr,
278                          const tensor::TensorPtr &status_ptr);
279 
280   void WarmUpHostCacheAsync(const int32_t batch_count);
281 
282   std::pair<std::shared_ptr<std::future<bool>>, bool> GetWarmUpHostCacheAsyncStatus();
283 
284   bool WaitForWarmUpHostCacheComplete();
285 
286   const HashTableInfo *FindHashTablesByParamKey(const int param_key);
287 
host_cache_ptrs()288   const WarmUpCacheMap &host_cache_ptrs() { return host_cache_ptrs_; }
289 
hash_tables()290   std::map<std::string, HashTableInfo> &hash_tables() { return hash_tables_; }
291 
set_host_hash_map(const std::shared_ptr<EmbeddingHashMap> & host_hash_map)292   void set_host_hash_map(const std::shared_ptr<EmbeddingHashMap> &host_hash_map) { host_hash_map_ = host_hash_map; }
293 
294  private:
295   EmbeddingCacheTableManager() = default;
296   ~EmbeddingCacheTableManager() = default;
297   DISABLE_COPY_AND_ASSIGN(EmbeddingCacheTableManager);
298 
299   // Get embedding table slice bound info on each worker in a multi-worker automatic parallel scenario.
300   void GetEmbeddingTableSliceBound();
301 
302   void WarmUpHostCacheItemBatch(const int32_t thread_count, const WarmUpCacheMapEntry &entry);
303 
304   void WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> &embedding_hash_map,
305                            const HashTableInfo *hash_table_info_ptr, const WarmUpCacheMapEntry &entry, const int start,
306                            const int end, const size_t value_len);
307 
308   void WarmUpHostCacheSync(const int32_t batch_count);
309 
310   std::atomic<bool> checkpoint_load_status_{false};
311 
312   WarmUpCacheMap host_cache_ptrs_;
313 
314   std::mutex host_cache_mutex_;
315 
316   std::shared_ptr<std::promise<bool>> host_cache_promise_{nullptr};
317 
318   // The hash tables records information such as the dimension, memory address, and cache size of the embedding table
319   // with the embedding cache enabled.
320   std::map<std::string, HashTableInfo> hash_tables_;
321 
322   std::shared_ptr<EmbeddingHashMap> device_hash_map_;
323 
324   std::shared_ptr<EmbeddingHashMap> host_hash_map_;
325 
326   int *hash_swap_index_addr_;
327   float *hash_swap_value_addr_;
328 
329   // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
330   // corresponding to the embedding table slice of the process.
331   std::pair<int, int> local_embedding_slice_bounds_;
332 
333   // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache
334   // range corresponding to the embedding table slice of the process.
335   std::pair<int, int> local_device_cache_bounds_;
336 
337   // Full Embedding table row num, not less than the total number of feature ids.
338   size_t vocab_size_{0};
339   // Embedding cache size(row number of embedding cache) of device cache.
340   size_t device_cache_size_{0};
341   // Embedding cache size(row number of embedding cache) of local host cache.
342   size_t host_cache_size_{0};
343   // Total ids number of a batchsize.
344   size_t batch_ids_num_{0};
345 
346   // If the storage format is sparse or dense, the default format is dense.
347   bool sparse_format_{false};
348 
349   // The batch number once cache prefetch.
350   size_t multi_batch_threshold_;
351 
352   // Record whether multi-stage pipeline cache prefetch is enabled.
353   bool enable_pipeline_{false};
354 
355   device::DeviceContext *cpu_device_context_{nullptr};
356 
357   friend class mindspore::runtime::EmbeddingCachePrefetchActor;
358   friend class mindspore::runtime::DeviceEmbeddingOperation;
359   friend class mindspore::runtime::DeviceDenseEmbeddingOperation;
360   friend class mindspore::runtime::DeviceSparseEmbeddingOperation;
361 };
362 
363 /**
364  * @brief A single instance class used to manager all EmbeddingStorage instances, EmbeddingStorage is encapsulated
365  * within the Huge Embedding Table's lookup and update. EmbeddingStorageManager provides Add and Get API to add, replace
366  * and acquire EmbeddingStorage instances.
367  */
368 class BACKEND_EXPORT EmbeddingStorageManager {
369  public:
370   static EmbeddingStorageManager &GetInstance();
371 
372   /**
373    * @brief Add the embedding storage instance corresponding to the parameter key, if embedding storage instance already
374    * exists, replace it by input parameter `embed_storage'.
375    * @param[in] `param_key`: The parameter key for embedding table which need to add.
376    * @param[in] `embed_storage`: The embedding storage instance pointer which can not be nullptr.
377    */
378   void Add(int32_t param_key, const std::shared_ptr<storage::AbstractEmbeddingStorage> &embed_storage);
379 
380   /**
381    * @brief Try get the embedding storage instance corresponding to the parameter key.
382    * @param[in] `param_key`: The parameter key for embedding table which need to acquire.
383    * @return The embedding storage instance pointer if the embedding storage already exists, else throw exception.
384    */
385   std::shared_ptr<storage::AbstractEmbeddingStorage> Get(int32_t param_key);
386 
387   /**
388    * @brief Check if the embedding storage instance corresponding to the parameter key already exists.
389    * @param[in] `param_key`: The parameter key for embedding table which need to check if the embedding storage already
390    * exists.
391    * @return true if the embedding storage already exists, else false.
392    */
Exists(int32_t param_key)393   bool Exists(int32_t param_key) const { return embedding_storages_.find(param_key) != embedding_storages_.end(); }
394 
395   /**
396    * @brief Clear all embedding storage instances and release related resources.
397    */
398   void Clear();
399 
400  private:
401   EmbeddingStorageManager() = default;
402   ~EmbeddingStorageManager() = default;
403   DISABLE_COPY_AND_ASSIGN(EmbeddingStorageManager);
404 
405   // Record all {parameter key -> embedding storage instance} pairs.
406   HashMap<int32_t, std::shared_ptr<storage::AbstractEmbeddingStorage>> embedding_storages_;
407 };
408 
409 /**
410  * @brief Create a new embedding storage instance for specific key and value type, and add the instance to
411  * EmbeddingStorageManager.
412  * @param[in] `key_value_types`: The specific key and value data type to determine the type of embedding storage
413  * instance to create.
414  * @param[in] `embedding_key`: The unique parameter key for embedding table.
415  * @param[in] `embedding_dim`: The size of each embedding vector.
416  * @param[in] `capacity`: The capacity for new embedding storage.
417  */
418 BACKEND_EXPORT void CreateEmbeddingStorage(std::pair<TypeId, TypeId> key_value_types, int32_t embedding_key,
419                                            size_t embedding_dim, size_t capacity);
420 }  // namespace distributed
421 
422 static distributed::EmbeddingCacheTableManager &embedding_cache_table_manager =
423   distributed::EmbeddingCacheTableManager::GetInstance();
424 
425 static distributed::EmbeddingStorageManager &embedding_storage_manager =
426   distributed::EmbeddingStorageManager::GetInstance();
427 }  // namespace mindspore
428 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_
429