1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ 18 #define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ 19 20 #include <map> 21 #include <string> 22 #include <vector> 23 #include <thread> 24 #include <atomic> 25 #include <utility> 26 #include <memory> 27 #include <condition_variable> 28 #include "utils/ms_context.h" 29 #include "backend/kernel_compiler/kernel.h" 30 #include "utils/shape_utils.h" 31 #include "ir/tensor.h" 32 #include "ps/constants.h" 33 #include "ps/worker.h" 34 #include "ps/ps_context.h" 35 #include "ps/ps_cache/ps_data/ps_data_prefetch.h" 36 #include "ps/ps_cache/embedding_hash_map.h" 37 #include "ps/ps_cache/ps_cache_factory.h" 38 39 namespace mindspore { 40 namespace ps { 41 constexpr size_t kHostCacheScaleFactor = 10; 42 constexpr size_t kMaxThreadNum = 16; 43 constexpr size_t kMaxIdsPerThread = 10000; 44 using mindspore::kernel::Address; 45 46 struct HashTableInfo { 47 size_t cache_vocab_size{0}; 48 size_t host_cache_vocab_size{0}; 49 size_t embedding_size{0}; 50 size_t vocab_size{0}; 51 Address device_address{nullptr, 0}; 52 std::shared_ptr<float[]> host_address{nullptr}; 53 ParamInitInfo param_init_info_; 54 }; 55 56 struct EmbeddingDeviceCache { EmbeddingDeviceCacheEmbeddingDeviceCache57 EmbeddingDeviceCache(size_t batch_elements, size_t cache_vocab_size) 58 : hash_swap_index_addr_(nullptr), hash_swap_value_addr_(nullptr) { 59 device_to_host_index = std::make_unique<int[]>(batch_elements); 60 device_to_host_ids = std::make_unique<int[]>(batch_elements); 61 host_to_device_index = std::make_unique<int[]>(batch_elements); 62 host_to_device_ids = std::make_unique<int[]>(batch_elements); 63 device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size); 64 auto context_ptr = MsContext::GetInstance(); 65 MS_EXCEPTION_IF_NULL(context_ptr); 66 auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); 67 cache_ = PsCacheFactory::Get().ps_cache(devcie_target); 68 } 69 std::unique_ptr<int[]> device_to_host_index; 70 std::unique_ptr<int[]> device_to_host_ids; 71 std::unique_ptr<int[]> host_to_device_index; 72 std::unique_ptr<int[]> host_to_device_ids; 73 int *hash_swap_index_addr_; 74 float *hash_swap_value_addr_; 75 std::shared_ptr<EmbeddingHashMap> device_hash_map_; 76 std::shared_ptr<PsCacheBasic> cache_; 77 }; 78 79 struct EmbeddingHostCache { EmbeddingHostCacheEmbeddingHostCache80 EmbeddingHostCache(size_t batch_elements, size_t host_cache_vocab_size) { 81 host_to_server_index = std::make_unique<int[]>(batch_elements); 82 host_to_server_ids = std::make_unique<int[]>(batch_elements); 83 server_to_host_index = std::make_unique<int[]>(batch_elements); 84 server_to_host_ids = std::make_unique<int[]>(batch_elements); 85 host_to_device_index = std::make_unique<int[]>(batch_elements); 86 device_to_host_index = std::make_unique<int[]>(batch_elements); 87 host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size); 88 } 89 std::unique_ptr<int[]> host_to_server_index; 90 std::unique_ptr<int[]> host_to_server_ids; 91 std::unique_ptr<int[]> server_to_host_index; 92 std::unique_ptr<int[]> server_to_host_ids; 93 std::unique_ptr<int[]> host_to_device_index; 94 std::unique_ptr<int[]> device_to_host_index; 95 std::shared_ptr<EmbeddingHashMap> host_hash_map_; 96 }; 97 98 struct PsCacheStatisticsInfo { 99 size_t batch_id_count_{0}; 100 size_t batch_id_unique_count_{0}; 101 size_t device_to_host_size_{0}; 102 size_t host_to_device_size_{0}; 103 size_t host_to_server_size_{0}; 104 size_t server_to_host_size_{0}; 105 size_t hash_hit_count_{0}; 106 size_t mem_cache_swap_out_size_{0}; 107 size_t mem_cache_swap_in_size_{0}; 108 size_t mem_cache_hit_count_{0}; 109 }; 110 111 class PsCacheManager { 112 public: GetInstance()113 static PsCacheManager &GetInstance() { 114 static PsCacheManager instance; 115 return instance; 116 } 117 void Initialize(); 118 void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, 119 size_t vocab_size); 120 void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed); 121 void InsertAccumuInitInfo(const std::string ¶m_name, float init_val); 122 void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, 123 size_t cache_vocab_size, size_t embedding_size); 124 void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); 125 const Address &QueryHashTableAddr(const std::string ¶m_name) const; 126 const size_t &QueryHashTableSize(const std::string ¶m_name) const; IsHashTable(const std::string & param_name)127 bool IsHashTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } set_batch_elements(size_t batch_elements)128 void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } set_rank_id(uint32_t rank_id)129 void set_rank_id(uint32_t rank_id) { rank_id_ = rank_id; } initialized_ps_cache()130 bool initialized_ps_cache() const { return initialized_ps_cache_; } vocab_cache_size()131 size_t vocab_cache_size() const { return vocab_cache_size_; } 132 int cache_indices_lower_bound() const; 133 void DoProcessData(uint32_t device_id, const void *context); 134 void IncreaseGraphStep(const std::string &channel_name); 135 void SyncEmbeddingTable(); 136 void Finalize(); 137 void DumpHashTables(bool dump_device_tables = false) const; 138 139 private: 140 PsCacheManager() = default; 141 ~PsCacheManager() = default; 142 PsCacheManager(const PsCacheManager &) = delete; 143 PsCacheManager &operator=(const PsCacheManager &) = delete; 144 bool IncreaseStep(); set_current_graph_step()145 void set_current_graph_step() { graph_running_step_ = graph_step_; } 146 std::string channel_name(); 147 void set_channel_name(const std::string channel_name); 148 void InitParameterServer(); 149 void InitDataChannel(); 150 void AllocMemForHashTable(); 151 void SetLocalIdRank(); 152 void ProcessDataTask(uint32_t device_id, const void *context); 153 bool ProcessData(); 154 bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); 155 bool WaitGraphRun(); 156 bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index); 157 bool ParseHostDataHostToDevice(size_t id); 158 bool ParseHostDataDeviceToHost(); 159 bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info); 160 bool HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, size_t key); 161 bool HashSwapHostToDevice(const HashTableInfo &hash_info); 162 bool HashSwapDeviceToHost(const HashTableInfo &hash_info); 163 bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); 164 bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); 165 bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices, 166 const float *insert_data, float *hash_table_addr); 167 bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, 168 const int *indices_addr, float *output_addr); 169 bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, size_t key); 170 void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, 171 const int *indices_addr, float *output_addr); 172 bool CheckFinishInsertInitInfo() const; 173 void AddEmbeddingTable() const; 174 void DumpStatisticsInfo(size_t each_print_step = 1000); 175 bool SyncHostEmbeddingTable(); 176 bool SyncDeviceEmbeddingTable(); 177 bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, 178 bool *out_range, size_t *hash_hit_count); 179 bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, 180 bool *out_range); 181 bool ResetEmbeddingHashMap(); 182 183 bool initialized_ps_cache_{false}; 184 std::string channel_name_; 185 std::mutex channel_mutex_; 186 std::atomic_ulong graph_step_{0}; 187 size_t graph_running_step_{0}; 188 size_t data_step_{0}; 189 std::mutex data_mutex_; 190 std::condition_variable data_prase_; 191 std::condition_variable insert_init_info_; 192 std::thread process_data_thread_; 193 194 std::map<std::string, HashTableInfo> hash_tables_; 195 std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; 196 std::shared_ptr<EmbeddingHostCache> embedding_host_cache_; 197 198 size_t vocab_size_{0}; 199 size_t vocab_cache_size_{0}; 200 size_t host_vocab_cache_size_{0}; 201 size_t batch_elements_{0}; 202 PsCacheStatisticsInfo statistics_info_; 203 std::pair<int, int> emb_table_slice_bounds_; 204 std::pair<int, int> cache_indices_bounds_; 205 int vocab_cache_size_diff_{0}; 206 uint32_t rank_id_{0}; 207 std::atomic_bool finish_insert_init_info_{false}; 208 std::atomic_bool finish_init_parameter_server_{false}; 209 std::atomic_bool running_{false}; 210 bool finish_embedding_table_sync_{false}; 211 bool device_need_wait_graph_{false}; 212 bool host_need_wait_graph_{false}; 213 }; 214 215 static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); 216 } // namespace ps 217 } // namespace mindspore 218 #endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ 219