• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
18 #define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
19 
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "utils/shape_utils.h"
#include "ir/tensor.h"
#include "ps/constants.h"
#include "ps/worker.h"
#include "ps/ps_context.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/embedding_hash_map.h"
#include "ps/ps_cache/ps_cache_factory.h"
38 
39 namespace mindspore {
40 namespace ps {
// Tuning constants for the PS cache.
// Declared `inline constexpr` (C++17) so every translation unit that includes this header
// shares a single definition. A plain namespace-scope `constexpr` has internal linkage,
// and odr-using it (e.g. binding it to `const size_t &` via std::min) from an inline
// function defined in a header would be an ODR violation.
// NOTE(review): semantics inferred from names only — confirm against the .cc:
//   kHostCacheScaleFactor: presumably host cache size relative to the device cache.
//   kMaxThreadNum:         presumably an upper bound on worker threads.
//   kMaxIdsPerThread:      presumably how many ids one worker thread handles.
inline constexpr size_t kHostCacheScaleFactor = 10;
inline constexpr size_t kMaxThreadNum = 16;
inline constexpr size_t kMaxIdsPerThread = 10000;
44 using mindspore::kernel::Address;
45 
46 struct HashTableInfo {
47   size_t cache_vocab_size{0};
48   size_t host_cache_vocab_size{0};
49   size_t embedding_size{0};
50   size_t vocab_size{0};
51   Address device_address{nullptr, 0};
52   std::shared_ptr<float[]> host_address{nullptr};
53   ParamInitInfo param_init_info_;
54 };
55 
56 struct EmbeddingDeviceCache {
EmbeddingDeviceCacheEmbeddingDeviceCache57   EmbeddingDeviceCache(size_t batch_elements, size_t cache_vocab_size)
58       : hash_swap_index_addr_(nullptr), hash_swap_value_addr_(nullptr) {
59     device_to_host_index = std::make_unique<int[]>(batch_elements);
60     device_to_host_ids = std::make_unique<int[]>(batch_elements);
61     host_to_device_index = std::make_unique<int[]>(batch_elements);
62     host_to_device_ids = std::make_unique<int[]>(batch_elements);
63     device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size);
64     auto context_ptr = MsContext::GetInstance();
65     MS_EXCEPTION_IF_NULL(context_ptr);
66     auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
67     cache_ = PsCacheFactory::Get().ps_cache(devcie_target);
68   }
69   std::unique_ptr<int[]> device_to_host_index;
70   std::unique_ptr<int[]> device_to_host_ids;
71   std::unique_ptr<int[]> host_to_device_index;
72   std::unique_ptr<int[]> host_to_device_ids;
73   int *hash_swap_index_addr_;
74   float *hash_swap_value_addr_;
75   std::shared_ptr<EmbeddingHashMap> device_hash_map_;
76   std::shared_ptr<PsCacheBasic> cache_;
77 };
78 
79 struct EmbeddingHostCache {
EmbeddingHostCacheEmbeddingHostCache80   EmbeddingHostCache(size_t batch_elements, size_t host_cache_vocab_size) {
81     host_to_server_index = std::make_unique<int[]>(batch_elements);
82     host_to_server_ids = std::make_unique<int[]>(batch_elements);
83     server_to_host_index = std::make_unique<int[]>(batch_elements);
84     server_to_host_ids = std::make_unique<int[]>(batch_elements);
85     host_to_device_index = std::make_unique<int[]>(batch_elements);
86     device_to_host_index = std::make_unique<int[]>(batch_elements);
87     host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size);
88   }
89   std::unique_ptr<int[]> host_to_server_index;
90   std::unique_ptr<int[]> host_to_server_ids;
91   std::unique_ptr<int[]> server_to_host_index;
92   std::unique_ptr<int[]> server_to_host_ids;
93   std::unique_ptr<int[]> host_to_device_index;
94   std::unique_ptr<int[]> device_to_host_index;
95   std::shared_ptr<EmbeddingHashMap> host_hash_map_;
96 };
97 
// Plain counters for cache statistics; every field starts at zero.
// NOTE(review): the meaning of each counter is inferred from its name — confirm in the
// code that updates and prints them.
struct PsCacheStatisticsInfo {
  // Ids seen in batches (total / unique).
  size_t batch_id_count_ = 0;
  size_t batch_id_unique_count_ = 0;
  // Swap volumes between device, host and server.
  size_t device_to_host_size_ = 0;
  size_t host_to_device_size_ = 0;
  size_t host_to_server_size_ = 0;
  size_t server_to_host_size_ = 0;
  // Hash-map hit counter.
  size_t hash_hit_count_ = 0;
  // Memory-cache swap and hit counters.
  size_t mem_cache_swap_out_size_ = 0;
  size_t mem_cache_swap_in_size_ = 0;
  size_t mem_cache_hit_count_ = 0;
};
110 
111 class PsCacheManager {
112  public:
GetInstance()113   static PsCacheManager &GetInstance() {
114     static PsCacheManager instance;
115     return instance;
116   }
117   void Initialize();
118   void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
119                            size_t vocab_size);
120   void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed);
121   void InsertAccumuInitInfo(const std::string &param_name, float init_val);
122   void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
123                              size_t cache_vocab_size, size_t embedding_size);
124   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name);
125   const Address &QueryHashTableAddr(const std::string &param_name) const;
126   const size_t &QueryHashTableSize(const std::string &param_name) const;
IsHashTable(const std::string & param_name)127   bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
set_batch_elements(size_t batch_elements)128   void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
set_rank_id(uint32_t rank_id)129   void set_rank_id(uint32_t rank_id) { rank_id_ = rank_id; }
initialized_ps_cache()130   bool initialized_ps_cache() const { return initialized_ps_cache_; }
vocab_cache_size()131   size_t vocab_cache_size() const { return vocab_cache_size_; }
132   int cache_indices_lower_bound() const;
133   void DoProcessData(uint32_t device_id, const void *context);
134   void IncreaseGraphStep(const std::string &channel_name);
135   void SyncEmbeddingTable();
136   void Finalize();
137   void DumpHashTables(bool dump_device_tables = false) const;
138 
139  private:
140   PsCacheManager() = default;
141   ~PsCacheManager() = default;
142   PsCacheManager(const PsCacheManager &) = delete;
143   PsCacheManager &operator=(const PsCacheManager &) = delete;
144   bool IncreaseStep();
set_current_graph_step()145   void set_current_graph_step() { graph_running_step_ = graph_step_; }
146   std::string channel_name();
147   void set_channel_name(const std::string channel_name);
148   void InitParameterServer();
149   void InitDataChannel();
150   void AllocMemForHashTable();
151   void SetLocalIdRank();
152   void ProcessDataTask(uint32_t device_id, const void *context);
153   bool ProcessData();
154   bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
155   bool WaitGraphRun();
156   bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
157   bool ParseHostDataHostToDevice(size_t id);
158   bool ParseHostDataDeviceToHost();
159   bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info);
160   bool HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, size_t key);
161   bool HashSwapHostToDevice(const HashTableInfo &hash_info);
162   bool HashSwapDeviceToHost(const HashTableInfo &hash_info);
163   bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
164   bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
165   bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
166                            const float *insert_data, float *hash_table_addr);
167   bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
168                            const int *indices_addr, float *output_addr);
169   bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, size_t key);
170   void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
171                        const int *indices_addr, float *output_addr);
172   bool CheckFinishInsertInitInfo() const;
173   void AddEmbeddingTable() const;
174   void DumpStatisticsInfo(size_t each_print_step = 1000);
175   bool SyncHostEmbeddingTable();
176   bool SyncDeviceEmbeddingTable();
177   bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
178                                    bool *out_range, size_t *hash_hit_count);
179   bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
180                                bool *out_range);
181   bool ResetEmbeddingHashMap();
182 
183   bool initialized_ps_cache_{false};
184   std::string channel_name_;
185   std::mutex channel_mutex_;
186   std::atomic_ulong graph_step_{0};
187   size_t graph_running_step_{0};
188   size_t data_step_{0};
189   std::mutex data_mutex_;
190   std::condition_variable data_prase_;
191   std::condition_variable insert_init_info_;
192   std::thread process_data_thread_;
193 
194   std::map<std::string, HashTableInfo> hash_tables_;
195   std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
196   std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
197 
198   size_t vocab_size_{0};
199   size_t vocab_cache_size_{0};
200   size_t host_vocab_cache_size_{0};
201   size_t batch_elements_{0};
202   PsCacheStatisticsInfo statistics_info_;
203   std::pair<int, int> emb_table_slice_bounds_;
204   std::pair<int, int> cache_indices_bounds_;
205   int vocab_cache_size_diff_{0};
206   uint32_t rank_id_{0};
207   std::atomic_bool finish_insert_init_info_{false};
208   std::atomic_bool finish_init_parameter_server_{false};
209   std::atomic_bool running_{false};
210   bool finish_embedding_table_sync_{false};
211   bool device_need_wait_graph_{false};
212   bool host_need_wait_graph_{false};
213 };
214 
215 static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
216 }  // namespace ps
217 }  // namespace mindspore
218 #endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
219