1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ 18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ 19 20 #include <future> 21 #include <map> 22 #include <string> 23 #include <memory> 24 #include <vector> 25 #include <tuple> 26 #include <utility> 27 #include "kernel/kernel.h" 28 #include "runtime/hardware/device_context.h" 29 #include "include/backend/visible.h" 30 #include "include/backend/distributed/embedding_cache/embedding_storage/abstract_embedding_storage.h" 31 #include "include/backend/distributed/embedding_cache/embedding_hash_map.h" 32 #include "include/backend/distributed/embedding_cache/blocking_queue.h" 33 #include "include/backend/data_queue/data_queue.h" 34 35 namespace mindspore { 36 namespace runtime { 37 class EmbeddingCachePrefetchActor; 38 class DeviceEmbeddingOperation; 39 class DeviceDenseEmbeddingOperation; 40 class DeviceSparseEmbeddingOperation; 41 } // namespace runtime 42 43 namespace distributed { 44 // The local host cache size defaults to 10 times the device cache size. 45 static constexpr size_t kHostCacheScaleFactor = 10; 46 // The maximum number of concurrent threads for data prefetching. 47 static constexpr size_t kMaxThreadNum = 16; 48 // Maximum number of feature ids processed per thread. 
49 static constexpr size_t kMaxIdsPerThread = 10000; 50 51 // Prefetch 16 batchs data once. 52 static constexpr size_t kMultiBatchThreshold = 16; 53 54 using mindspore::device::DeviceAddress; 55 using mindspore::kernel::Address; 56 57 // The type of embedding tables. 58 enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 }; 59 60 // The initialization information for embedding tables. 61 struct ParamInitInfo { 62 std::string param_name_; 63 ParamType param_type_{kUnKnown}; 64 size_t global_seed_{0}; 65 size_t op_seed_{0}; 66 float init_val_{0}; 67 }; 68 69 // The hash tables records information such as the dimension, memory address, and cache size of the embedding table 70 // with the embedding cache enabled. 71 struct HashTableInfo { 72 size_t cache_vocab_size{0}; 73 size_t host_cache_vocab_size{0}; 74 size_t embedding_size{0}; 75 size_t vocab_size{0}; 76 // For performance, set address the snapshot of device_address. 77 Address address{nullptr, 0}; 78 DeviceAddress *device_address{nullptr}; 79 float *host_address{nullptr}; 80 ParamInitInfo param_init_info_; 81 int32_t param_key_{-1}; 82 }; 83 84 // Record the hash mapping relationship of all embedding tables with cache enabled on the device side, and the 85 // ids information that needs to be exchanged with the local host cache. Note that the following information of 86 // all embedding cache tables on the device side is same: hash mapping, and feature ids of feature vectors that need 87 // to be swapped with the local host cache. 
struct EmbeddingDeviceCache {
  explicit EmbeddingDeviceCache(size_t batch_ids_num);

  // Indices/ids of rows to evict from device cache to host cache.
  std::unique_ptr<int[]> device_to_host_index;
  std::unique_ptr<int[]> device_to_host_ids;
  // Indices/ids of rows to load from host cache into device cache.
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> host_to_device_ids;
};

// Record the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the
// information that needs to be exchanged with the remote cache and device cache. Note that the following information of
// all embedding cache tables on the local host side is same: hash mapping, and feature ids of feature vectors that need
// to be swapped with the remote cache and device cache.
struct EmbeddingHostCache {
  explicit EmbeddingHostCache(size_t batch_ids_num);

  // Rows to push from the local host cache to the remote server.
  std::unique_ptr<int[]> host_to_server_index;
  std::unique_ptr<int[]> host_to_server_ids;
  // Rows to pull from the remote server into the local host cache.
  std::unique_ptr<int[]> server_to_host_index;
  std::unique_ptr<int[]> server_to_host_ids;
  // Slots for ids seen for the first time.
  std::unique_ptr<int[]> new_id_index;
  // Rows exchanged between the local host cache and the device cache.
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> device_to_host_index;
};

// Counters describing one round of cache analysis/swapping.
struct EmbeddingCacheStatisticsInfo {
  size_t batch_id_count_{0};
  size_t batch_id_unique_count_{0};
  size_t device_to_host_size_{0};
  size_t host_to_device_size_{0};
  size_t host_to_server_size_{0};
  size_t server_to_host_size_{0};
  size_t new_id_size_{0};
  size_t hash_hit_count_{0};
  size_t mem_cache_swap_out_size_{0};
  size_t mem_cache_swap_in_size_{0};
  size_t mem_cache_hit_count_{0};
};

// Origin id data item recorder.
128 struct IdDataInfo { 129 IdDataInfo() = default; IdDataInfoIdDataInfo130 IdDataInfo(void *data, size_t size, std::vector<device::DataQueueItem> *items, bool end_of_epoch, bool end_of_file) 131 : data_(data), size_(size), items_(items), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {} 132 133 void *data_{nullptr}; 134 size_t size_{0}; 135 std::vector<device::DataQueueItem> *items_{nullptr}; 136 bool end_of_epoch_{false}; 137 bool end_of_file_{false}; 138 }; 139 140 // The indexes data after cache prefetch. 141 struct IndexDataInfo { 142 IndexDataInfo() = default; IndexDataInfoIndexDataInfo143 IndexDataInfo(void *data, std::vector<device::DataQueueItem> *items, bool end_of_epoch, bool end_of_file) 144 : data_(data), items_(items), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {} 145 146 void *data_{nullptr}; 147 std::vector<device::DataQueueItem> *items_{nullptr}; 148 bool end_of_epoch_{false}; 149 bool end_of_file_{false}; 150 }; 151 152 // The origin unique data recorder. 153 struct UniqueIds { 154 UniqueIds() = default; 155 156 size_t data_step_{0}; 157 std::vector<void *> multi_batch_data_; 158 std::vector<size_t> multi_batch_size_; 159 std::vector<std::vector<device::DataQueueItem> *> multi_batch_items_; 160 int *ids_{nullptr}; 161 size_t ids_num_{0}; 162 163 bool end_of_epoch_{false}; 164 bool end_of_file_{false}; 165 }; 166 167 // Record all information used to analyse cache. 
168 struct CacheAnalysis { 169 CacheAnalysis() = default; CacheAnalysisCacheAnalysis170 CacheAnalysis(EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache, 171 EmbeddingCacheStatisticsInfo *statistics_info, UniqueIds *unique_ids, int *indices, bool end_of_epoch, 172 bool end_of_file) 173 : embedding_device_cache_(embedding_device_cache), 174 embedding_host_cache_(embedding_host_cache), 175 statistics_info_(statistics_info), 176 unique_ids_(unique_ids), 177 indices_(indices), 178 end_of_epoch_(end_of_epoch), 179 end_of_file_(end_of_file) {} 180 181 // Record the ids information that needs to be exchanged with the local host cache. 182 EmbeddingDeviceCache *embedding_device_cache_{nullptr}; 183 // Record the information that needs to be exchanged with the remote cache and device cache. 184 EmbeddingHostCache *embedding_host_cache_{nullptr}; 185 EmbeddingCacheStatisticsInfo *statistics_info_{nullptr}; 186 UniqueIds *unique_ids_{nullptr}; 187 int *indices_{nullptr}; 188 bool end_of_epoch_{false}; 189 bool end_of_file_{false}; 190 }; 191 192 // Record all ids(after unique) and indices(after cache analysis) 193 struct IdsAndIndices { 194 IdsAndIndices() = default; IdsAndIndicesIdsAndIndices195 IdsAndIndices(UniqueIds *unique_ids, int *indices, bool end_of_epoch, bool end_of_file) 196 : unique_ids_(unique_ids), indices_(indices), end_of_epoch_(end_of_epoch), end_of_file_(end_of_file) {} 197 198 UniqueIds *unique_ids_{nullptr}; 199 int *indices_{nullptr}; 200 bool end_of_epoch_{false}; 201 bool end_of_file_{false}; 202 }; 203 204 // The EmbeddingCacheTableManager class is used to save all Parameter information for enabling cache, such as device 205 // cache size, host cache size, etc., and can allocate memory for the embedding cache table. 
206 class BACKEND_EXPORT EmbeddingCacheTableManager { 207 public: 208 using WarmUpCacheMapValue = std::tuple<tensor::TensorPtr, tensor::TensorPtr, tensor::TensorPtr>; 209 using WarmUpCacheMapEntry = std::pair<int32_t, WarmUpCacheMapValue>; 210 using WarmUpCacheMap = std::map<int32_t, WarmUpCacheMapValue>; 211 static EmbeddingCacheTableManager &GetInstance(); 212 213 // Initialize the EmbeddingCacheTableManager. 214 void Initialize(); 215 // Finalize the EmbeddingCacheTableManager and release all resource. 216 void Finalize(const device::DeviceContext *device_context); 217 218 // Insert and save dimension information of the embedding cache table. 219 void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, 220 size_t vocab_size, int32_t param_key); 221 222 // Parameter will modify the name. After modification, you need to re-insert all the dimension information that saves 223 // the parameter. 224 void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name); 225 226 // Insert the initial value for the accumulation value of embedding's optimizer. 227 void InsertAccumuInitInfo(const std::string ¶m_name, float init_val); 228 229 // Clone a hash table, such as the optimizer's state parameters are generally cloned from weight. 230 void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name, 231 int32_t src_param_key); 232 233 // Set the device address for embedding cache table, using the same device address with parameter node. 234 void SetEmbeddingDeviceAddress(const std::string ¶m_name, DeviceAddress *device_address); 235 236 // Alloc device memory for all embedding cache table. 237 void AllocMemForEmbedding(const device::DeviceContext *device_context); 238 239 // Qeury device address of a embedding cache table. 
240 const DeviceAddress *QueryEmbeddingDeviceAddress(const std::string ¶m_name) const; 241 242 // Qeury device cache size of a embedding cache table. 243 size_t QueryHashTableSize(const std::string ¶m_name) const; 244 245 // Check whether a parameter is cache enabled embedding table. IsEmbeddingCacheTable(const std::string & param_name)246 bool IsEmbeddingCacheTable(const std::string ¶m_name) const { return hash_tables_.count(param_name) != 0; } 247 248 // Set ids number of a batchsize. set_batch_ids_num(size_t batch_ids_num)249 void set_batch_ids_num(size_t batch_ids_num) { batch_ids_num_ = batch_ids_num; } 250 251 // Get the offset of the id range corresponding to the embedding cache table slice on each worker in a multi-worker 252 // automatic parallel scenario. 253 int cache_indices_lower_bound() const; 254 255 // Set embedding vocab cache size on device. set_cache_size(size_t cache_size)256 void set_cache_size(size_t cache_size) { device_cache_size_ = cache_size; } 257 258 // Get embedding vocab cache size on device. cache_size()259 size_t cache_size() const { return device_cache_size_; } 260 261 // Set the storage format (`dense` or `sparse`) of embedding tables. set_sparse_format(bool is_sparse)262 void set_sparse_format(bool is_sparse) { sparse_format_ = is_sparse; } 263 is_sparse_format()264 bool is_sparse_format() { return sparse_format_; } 265 266 // Get whether multi-stage pipeline cache prefetch is enabled. 
267 bool enable_pipeline() const; 268 269 void DumpHashTables() const; 270 checkpoint_load_status()271 bool checkpoint_load_status() const { return checkpoint_load_status_; } 272 set_checkpoint_load_status(bool checkpoint_load_status)273 void set_checkpoint_load_status(bool checkpoint_load_status) { checkpoint_load_status_ = checkpoint_load_status; } 274 275 int32_t StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &tensor_ptr); 276 277 int32_t StoreWarmUpPtr(const int32_t param_key, const tensor::TensorPtr &key_ptr, const tensor::TensorPtr &value_ptr, 278 const tensor::TensorPtr &status_ptr); 279 280 void WarmUpHostCacheAsync(const int32_t batch_count); 281 282 std::pair<std::shared_ptr<std::future<bool>>, bool> GetWarmUpHostCacheAsyncStatus(); 283 284 bool WaitForWarmUpHostCacheComplete(); 285 286 const HashTableInfo *FindHashTablesByParamKey(const int param_key); 287 host_cache_ptrs()288 const WarmUpCacheMap &host_cache_ptrs() { return host_cache_ptrs_; } 289 hash_tables()290 std::map<std::string, HashTableInfo> &hash_tables() { return hash_tables_; } 291 set_host_hash_map(const std::shared_ptr<EmbeddingHashMap> & host_hash_map)292 void set_host_hash_map(const std::shared_ptr<EmbeddingHashMap> &host_hash_map) { host_hash_map_ = host_hash_map; } 293 294 private: 295 EmbeddingCacheTableManager() = default; 296 ~EmbeddingCacheTableManager() = default; 297 DISABLE_COPY_AND_ASSIGN(EmbeddingCacheTableManager); 298 299 // Get embedding table slice bound info on each worker in a multi-worker automatic parallel scenario. 
300 void GetEmbeddingTableSliceBound(); 301 302 void WarmUpHostCacheItemBatch(const int32_t thread_count, const WarmUpCacheMapEntry &entry); 303 304 void WarmUpHostCacheItem(const std::shared_ptr<EmbeddingHashMap> &embedding_hash_map, 305 const HashTableInfo *hash_table_info_ptr, const WarmUpCacheMapEntry &entry, const int start, 306 const int end, const size_t value_len); 307 308 void WarmUpHostCacheSync(const int32_t batch_count); 309 310 std::atomic<bool> checkpoint_load_status_{false}; 311 312 WarmUpCacheMap host_cache_ptrs_; 313 314 std::mutex host_cache_mutex_; 315 316 std::shared_ptr<std::promise<bool>> host_cache_promise_{nullptr}; 317 318 // The hash tables records information such as the dimension, memory address, and cache size of the embedding table 319 // with the embedding cache enabled. 320 std::map<std::string, HashTableInfo> hash_tables_; 321 322 std::shared_ptr<EmbeddingHashMap> device_hash_map_; 323 324 std::shared_ptr<EmbeddingHashMap> host_hash_map_; 325 326 int *hash_swap_index_addr_; 327 float *hash_swap_value_addr_; 328 329 // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range 330 // corresponding to the embedding table slice of the process. 331 std::pair<int, int> local_embedding_slice_bounds_; 332 333 // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache 334 // range corresponding to the embedding table slice of the process. 335 std::pair<int, int> local_device_cache_bounds_; 336 337 // Full Embedding table row num, not less than the total number of feature ids. 338 size_t vocab_size_{0}; 339 // Embedding cache size(row number of embedding cache) of device cache. 340 size_t device_cache_size_{0}; 341 // Embedding cache size(row number of embedding cache) of local host cache. 342 size_t host_cache_size_{0}; 343 // Total ids number of a batchsize. 
344 size_t batch_ids_num_{0}; 345 346 // If the storage format is sparse or dense, the default format is dense. 347 bool sparse_format_{false}; 348 349 // The batch number once cache prefetch. 350 size_t multi_batch_threshold_; 351 352 // Record whether multi-stage pipeline cache prefetch is enabled. 353 bool enable_pipeline_{false}; 354 355 device::DeviceContext *cpu_device_context_{nullptr}; 356 357 friend class mindspore::runtime::EmbeddingCachePrefetchActor; 358 friend class mindspore::runtime::DeviceEmbeddingOperation; 359 friend class mindspore::runtime::DeviceDenseEmbeddingOperation; 360 friend class mindspore::runtime::DeviceSparseEmbeddingOperation; 361 }; 362 363 /** 364 * @brief A single instance class used to manager all EmbeddingStorage instances, EmbeddingStorage is encapsulated 365 * within the Huge Embedding Table's lookup and update. EmbeddingStorageManager provides Add and Get API to add, replace 366 * and acquire EmbeddingStorage instances. 367 */ 368 class BACKEND_EXPORT EmbeddingStorageManager { 369 public: 370 static EmbeddingStorageManager &GetInstance(); 371 372 /** 373 * @brief Add the embedding storage instance corresponding to the parameter key, if embedding storage instance already 374 * exists, replace it by input parameter `embed_storage'. 375 * @param[in] `param_key`: The parameter key for embedding table which need to add. 376 * @param[in] `embed_storage`: The embedding storage instance pointer which can not be nullptr. 377 */ 378 void Add(int32_t param_key, const std::shared_ptr<storage::AbstractEmbeddingStorage> &embed_storage); 379 380 /** 381 * @brief Try get the embedding storage instance corresponding to the parameter key. 382 * @param[in] `param_key`: The parameter key for embedding table which need to acquire. 383 * @return The embedding storage instance pointer if the embedding storage already exists, else throw exception. 
384 */ 385 std::shared_ptr<storage::AbstractEmbeddingStorage> Get(int32_t param_key); 386 387 /** 388 * @brief Check if the embedding storage instance corresponding to the parameter key already exists. 389 * @param[in] `param_key`: The parameter key for embedding table which need to check if the embedding storage already 390 * exists. 391 * @return true if the embedding storage already exists, else false. 392 */ Exists(int32_t param_key)393 bool Exists(int32_t param_key) const { return embedding_storages_.find(param_key) != embedding_storages_.end(); } 394 395 /** 396 * @brief Clear all embedding storage instances and release related resources. 397 */ 398 void Clear(); 399 400 private: 401 EmbeddingStorageManager() = default; 402 ~EmbeddingStorageManager() = default; 403 DISABLE_COPY_AND_ASSIGN(EmbeddingStorageManager); 404 405 // Record all {parameter key -> embedding storage instance} pairs. 406 HashMap<int32_t, std::shared_ptr<storage::AbstractEmbeddingStorage>> embedding_storages_; 407 }; 408 409 /** 410 * @brief Create a new embedding storage instance for specific key and value type, and add the instance to 411 * EmbeddingStorageManager. 412 * @param[in] `key_value_types`: The specific key and value data type to determine the type of embedding storage 413 * instance to create. 414 * @param[in] `embedding_key`: The unique parameter key for embedding table. 415 * @param[in] `embedding_dim`: The size of each embedding vector. 416 * @param[in] `capacity`: The capacity for new embedding storage. 
417 */ 418 BACKEND_EXPORT void CreateEmbeddingStorage(std::pair<TypeId, TypeId> key_value_types, int32_t embedding_key, 419 size_t embedding_dim, size_t capacity); 420 } // namespace distributed 421 422 static distributed::EmbeddingCacheTableManager &embedding_cache_table_manager = 423 distributed::EmbeddingCacheTableManager::GetInstance(); 424 425 static distributed::EmbeddingStorageManager &embedding_storage_manager = 426 distributed::EmbeddingStorageManager::GetInstance(); 427 } // namespace mindspore 428 #endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ 429