1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ 18 #define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ 19 20 #include <math.h> 21 #include <utility> 22 #include <memory> 23 #include <vector> 24 #include <unordered_map> 25 #include "utils/convert_utils_base.h" 26 27 namespace mindspore { 28 namespace ps { 29 static const size_t INVALID_STEP_VALUE = 0; 30 static const int INVALID_INDEX_VALUE = -1; 31 32 struct HashMapElement { 33 int id_{INVALID_INDEX_VALUE}; 34 size_t step_{INVALID_STEP_VALUE}; IsEmptyHashMapElement35 bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } IsExpiredHashMapElement36 bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } IsStepHashMapElement37 bool IsStep(size_t step) const { return step_ == step; } set_idHashMapElement38 void set_id(int id) { id_ = id; } set_stepHashMapElement39 void set_step(size_t step) { step_ = step; } 40 }; 41 42 // Hash table is held in device, HashMap is used to manage hash table in host. 43 class EmbeddingHashMap { 44 public: EmbeddingHashMap(size_t hash_count,size_t hash_capacity)45 EmbeddingHashMap(size_t hash_count, size_t hash_capacity) 46 : hash_count_(hash_count), 47 hash_capacity_(hash_capacity), 48 current_pos_(0), 49 current_batch_start_pos_(0), 50 graph_running_index_num_(0), 51 graph_running_index_pos_(0), 52 expired_element_full_(false) { 53 hash_map_elements_.resize(hash_capacity); 54 // In multi-device mode, embedding table are distributed on different devices by ID interval, 55 // and IDs outside the range of local device will use the front and back positions of the table, 56 // the positions are reserved for this. 57 hash_map_elements_.front().set_step(SIZE_MAX); 58 hash_map_elements_.back().set_step(SIZE_MAX); 59 graph_running_index_ = std::make_unique<int[]>(hash_capacity); 60 } 61 virtual ~EmbeddingHashMap() = default; 62 int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step, 63 const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph); hash_step(const int hash_index)64 size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } set_hash_step(const int hash_index,const size_t step)65 void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } hash_id_to_index()66 const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; } hash_capacity()67 size_t hash_capacity() const { return hash_capacity_; } 68 void DumpHashMap(); 69 void Reset(); 70 71 private: 72 int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap, 73 bool *const need_wait_graph); 74 size_t hash_count_; 75 size_t hash_capacity_; 76 std::vector<HashMapElement> hash_map_elements_; 77 std::unordered_map<int, int> hash_id_to_index_; 78 size_t current_pos_; 79 size_t current_batch_start_pos_; 80 size_t graph_running_index_num_; 81 size_t graph_running_index_pos_; 82 std::unique_ptr<int[]> graph_running_index_; 83 bool expired_element_full_; 84 }; 85 } // namespace ps 86 } // namespace mindspore 87 #endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ 88