1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/distributed/embedding_cache/embedding_hash_map.h"
18 #include "distributed/embedding_cache/cache_strategy/lru_cache.h"
19
20 namespace mindspore {
21 namespace distributed {
EmbeddingHashMap(size_t hash_capacity)22 EmbeddingHashMap::EmbeddingHashMap(size_t hash_capacity) : hash_capacity_(hash_capacity), current_pos_(0) {
23 hash_map_elements_.resize(hash_capacity);
24 // In multi-device mode, embedding table are distributed on different devices by id interval,
25 // and ids outside the range of local device will use the front and back positions of the table(for Ascend platform,
26 // out-of-range ids will be Rectified by ReLU and Minimal operaotrs), so these two positions should be reserved for
27 // out-of-range ids, otherwise, the in range ids' embedding will be dirtied when optimizer updates the embeddings of
28 // out-of-range ids.
29 hash_map_elements_.front().set_step(SIZE_MAX);
30 hash_map_elements_.back().set_step(SIZE_MAX);
31 valid_capacity_ = hash_capacity > kMinimumCapacity ? (hash_capacity - kMinimumCapacity) : 0;
32 if (valid_capacity_ == 0) {
33 MS_LOG(ERROR) << "The invalid capacity is zero, please enlarge the capacity.";
34 }
35 ids_to_indices_ = std::make_unique<LRUCache<int, int>>(valid_capacity_);
36 }
37
hash_step(const int hash_index) const38 size_t EmbeddingHashMap::hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
39
set_hash_step(const int hash_index,const size_t step)40 void EmbeddingHashMap::set_hash_step(const int hash_index, const size_t step) {
41 hash_map_elements_[hash_index].set_step(step);
42 }
43
44 // Get capacity of hash map.
hash_capacity() const45 size_t EmbeddingHashMap::hash_capacity() const { return hash_capacity_; }
46
GetIndex(const int id,int * index) const47 bool EmbeddingHashMap::GetIndex(const int id, int *index) const { return ids_to_indices_->Get(id, index); }
48
Export() const49 const std::list<EmbeddingHashMap::Element> &EmbeddingHashMap::Export() const { return ids_to_indices_->Export(); }
50
ParseData(const int id,int * const swap_out_index,int * const swap_out_ids,const size_t data_step,const size_t graph_running_step,size_t * const swap_out_size,bool * const need_wait_graph)51 int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
52 const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size,
53 bool *const need_wait_graph) {
54 MS_EXCEPTION_IF_NULL(swap_out_index);
55 MS_EXCEPTION_IF_NULL(swap_out_ids);
56 MS_EXCEPTION_IF_NULL(swap_out_size);
57 bool need_swap = false;
58 int swap_out_id;
59 auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph, &swap_out_id);
60 if (hash_index == kInvalidIndexValue) {
61 return hash_index;
62 }
63
64 if (!need_swap) {
65 ids_to_indices_->Put(id, hash_index);
66 hash_map_elements_[hash_index].set_step(data_step);
67 return hash_index;
68 }
69
70 swap_out_index[*swap_out_size] = hash_index;
71 swap_out_ids[*swap_out_size] = swap_out_id;
72 ++(*swap_out_size);
73 ids_to_indices_->Put(id, hash_index);
74 hash_map_elements_[hash_index].set_step(data_step);
75 return hash_index;
76 }
77
GetOrInsertDataUnsafe(const int key)78 int EmbeddingHashMap::GetOrInsertDataUnsafe(const int key) {
79 int index = kInvalidIndexValue;
80 if (GetIndex(key, &index)) {
81 return index;
82 }
83
84 return InsertDataUnsafe(key);
85 }
86
InsertDataUnsafe(const int key)87 int EmbeddingHashMap::InsertDataUnsafe(const int key) {
88 auto hash_index = FindPosUnsafe();
89 if (hash_index == kInvalidIndexValue) {
90 MS_LOG(WARNING) << "Insert data unsafe failed as map is full.";
91 return hash_index;
92 }
93
94 ids_to_indices_->Put(key, hash_index);
95 hash_map_elements_[hash_index].set_step(1UL);
96 return hash_index;
97 }
98
FindPosUnsafe()99 int EmbeddingHashMap::FindPosUnsafe() {
100 if (current_pos_ >= valid_capacity_) {
101 return kInvalidIndexValue;
102 }
103 return static_cast<int>(++current_pos_);
104 }
105
FindInsertionPos(const size_t,const size_t graph_running_step,bool * const need_swap,bool * const need_wait_graph,int * swap_out_id)106 int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
107 bool *const need_wait_graph, int *swap_out_id) {
108 if (current_pos_ < valid_capacity_) {
109 // Start from index 1.
110 return ++current_pos_;
111 }
112 if (valid_capacity_ == 0) {
113 return kInvalidIndexValue;
114 }
115
116 *need_swap = true;
117 int id = ids_to_indices_->Back().first;
118 int index = ids_to_indices_->Back().second;
119 if (hash_map_elements_[index].IsExpired(graph_running_step)) {
120 std::vector<Element> evicted_elements;
121 ids_to_indices_->TryEvict(1, &evicted_elements);
122 if (evicted_elements.size() != 1) {
123 MS_LOG(EXCEPTION) << "Failed to evict tail element in cache, evict element number: " << evicted_elements.size()
124 << ", cache size: " << ids_to_indices_->size()
125 << ", cache capacity: " << ids_to_indices_->capacity();
126 }
127
128 *swap_out_id = evicted_elements.front().first;
129 if (*swap_out_id != id) {
130 MS_LOG(EXCEPTION) << "The evicted id should be: " << id << ", but got: " << *swap_out_id;
131 }
132 return index;
133 }
134 return kInvalidIndexValue;
135 }
136
DumpHashMap()137 void EmbeddingHashMap::DumpHashMap() {
138 MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_;
139 MS_LOG(INFO) << "Dump hash_id_to_index: ";
140 MS_LOG(INFO) << "Dump hash_map_unit: ";
141 for (size_t i = 0; i < hash_map_elements_.size(); i++) {
142 if (!hash_map_elements_[i].IsEmpty()) {
143 MS_LOG(INFO) << " index: " << i << " step: " << hash_map_elements_[i].step_;
144 }
145 }
146 MS_LOG(INFO) << "Dump hash map info end.";
147 }
148 } // namespace distributed
149 } // namespace mindspore
150