• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
18 #define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
19 
20 #include <math.h>
21 #include <utility>
22 #include <memory>
23 #include <vector>
24 #include <unordered_map>
25 #include "utils/convert_utils_base.h"
26 
27 namespace mindspore {
28 namespace ps {
// Step value of a hash map slot that has not been stamped with any step yet
// (HashMapElement::IsEmpty() compares against it).
static constexpr size_t INVALID_STEP_VALUE{0};
// Default id stored in a HashMapElement before set_id() is called.
static constexpr int INVALID_INDEX_VALUE{-1};
31 
32 struct HashMapElement {
33   int id_{INVALID_INDEX_VALUE};
34   size_t step_{INVALID_STEP_VALUE};
IsEmptyHashMapElement35   bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
IsExpiredHashMapElement36   bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
IsStepHashMapElement37   bool IsStep(size_t step) const { return step_ == step; }
set_idHashMapElement38   void set_id(int id) { id_ = id; }
set_stepHashMapElement39   void set_step(size_t step) { step_ = step; }
40 };
41 
42 // Hash table is held in device, HashMap is used to manage hash table in host.
43 class EmbeddingHashMap {
44  public:
EmbeddingHashMap(size_t hash_count,size_t hash_capacity)45   EmbeddingHashMap(size_t hash_count, size_t hash_capacity)
46       : hash_count_(hash_count),
47         hash_capacity_(hash_capacity),
48         current_pos_(0),
49         current_batch_start_pos_(0),
50         graph_running_index_num_(0),
51         graph_running_index_pos_(0),
52         expired_element_full_(false) {
53     hash_map_elements_.resize(hash_capacity);
54     // In multi-device mode, embedding table are distributed on different devices by ID interval,
55     // and IDs outside the range of local device will use the front and back positions of the table,
56     // the positions are reserved for this.
57     hash_map_elements_.front().set_step(SIZE_MAX);
58     hash_map_elements_.back().set_step(SIZE_MAX);
59     graph_running_index_ = std::make_unique<int[]>(hash_capacity);
60   }
61   virtual ~EmbeddingHashMap() = default;
62   int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step,
63                 const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph);
hash_step(const int hash_index)64   size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
set_hash_step(const int hash_index,const size_t step)65   void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); }
hash_id_to_index()66   const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; }
hash_capacity()67   size_t hash_capacity() const { return hash_capacity_; }
68   void DumpHashMap();
69   void Reset();
70 
71  private:
72   int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap,
73                        bool *const need_wait_graph);
74   size_t hash_count_;
75   size_t hash_capacity_;
76   std::vector<HashMapElement> hash_map_elements_;
77   std::unordered_map<int, int> hash_id_to_index_;
78   size_t current_pos_;
79   size_t current_batch_start_pos_;
80   size_t graph_running_index_num_;
81   size_t graph_running_index_pos_;
82   std::unique_ptr<int[]> graph_running_index_;
83   bool expired_element_full_;
84 };
85 }  // namespace ps
86 }  // namespace mindspore
87 #endif  // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
88