• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/backend/distributed/embedding_cache/embedding_hash_map.h"
18 #include "distributed/embedding_cache/cache_strategy/lru_cache.h"
19 
20 namespace mindspore {
21 namespace distributed {
EmbeddingHashMap(size_t hash_capacity)22 EmbeddingHashMap::EmbeddingHashMap(size_t hash_capacity) : hash_capacity_(hash_capacity), current_pos_(0) {
23   hash_map_elements_.resize(hash_capacity);
24   // In multi-device mode, embedding table are distributed on different devices by id interval,
25   // and ids outside the range of local device will use the front and back positions of the table(for Ascend platform,
26   // out-of-range ids will be Rectified by ReLU and Minimal operaotrs), so these two positions should be reserved for
27   // out-of-range ids, otherwise, the in range ids' embedding will be dirtied when optimizer updates the embeddings of
28   // out-of-range ids.
29   hash_map_elements_.front().set_step(SIZE_MAX);
30   hash_map_elements_.back().set_step(SIZE_MAX);
31   valid_capacity_ = hash_capacity > kMinimumCapacity ? (hash_capacity - kMinimumCapacity) : 0;
32   if (valid_capacity_ == 0) {
33     MS_LOG(ERROR) << "The invalid capacity is zero, please enlarge the capacity.";
34   }
35   ids_to_indices_ = std::make_unique<LRUCache<int, int>>(valid_capacity_);
36 }
37 
hash_step(const int hash_index) const38 size_t EmbeddingHashMap::hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
39 
set_hash_step(const int hash_index,const size_t step)40 void EmbeddingHashMap::set_hash_step(const int hash_index, const size_t step) {
41   hash_map_elements_[hash_index].set_step(step);
42 }
43 
44 // Get capacity of hash map.
hash_capacity() const45 size_t EmbeddingHashMap::hash_capacity() const { return hash_capacity_; }
46 
GetIndex(const int id,int * index) const47 bool EmbeddingHashMap::GetIndex(const int id, int *index) const { return ids_to_indices_->Get(id, index); }
48 
Export() const49 const std::list<EmbeddingHashMap::Element> &EmbeddingHashMap::Export() const { return ids_to_indices_->Export(); }
50 
ParseData(const int id,int * const swap_out_index,int * const swap_out_ids,const size_t data_step,const size_t graph_running_step,size_t * const swap_out_size,bool * const need_wait_graph)51 int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
52                                 const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size,
53                                 bool *const need_wait_graph) {
54   MS_EXCEPTION_IF_NULL(swap_out_index);
55   MS_EXCEPTION_IF_NULL(swap_out_ids);
56   MS_EXCEPTION_IF_NULL(swap_out_size);
57   bool need_swap = false;
58   int swap_out_id;
59   auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph, &swap_out_id);
60   if (hash_index == kInvalidIndexValue) {
61     return hash_index;
62   }
63 
64   if (!need_swap) {
65     ids_to_indices_->Put(id, hash_index);
66     hash_map_elements_[hash_index].set_step(data_step);
67     return hash_index;
68   }
69 
70   swap_out_index[*swap_out_size] = hash_index;
71   swap_out_ids[*swap_out_size] = swap_out_id;
72   ++(*swap_out_size);
73   ids_to_indices_->Put(id, hash_index);
74   hash_map_elements_[hash_index].set_step(data_step);
75   return hash_index;
76 }
77 
GetOrInsertDataUnsafe(const int key)78 int EmbeddingHashMap::GetOrInsertDataUnsafe(const int key) {
79   int index = kInvalidIndexValue;
80   if (GetIndex(key, &index)) {
81     return index;
82   }
83 
84   return InsertDataUnsafe(key);
85 }
86 
InsertDataUnsafe(const int key)87 int EmbeddingHashMap::InsertDataUnsafe(const int key) {
88   auto hash_index = FindPosUnsafe();
89   if (hash_index == kInvalidIndexValue) {
90     MS_LOG(WARNING) << "Insert data unsafe failed as map is full.";
91     return hash_index;
92   }
93 
94   ids_to_indices_->Put(key, hash_index);
95   hash_map_elements_[hash_index].set_step(1UL);
96   return hash_index;
97 }
98 
FindPosUnsafe()99 int EmbeddingHashMap::FindPosUnsafe() {
100   if (current_pos_ >= valid_capacity_) {
101     return kInvalidIndexValue;
102   }
103   return static_cast<int>(++current_pos_);
104 }
105 
FindInsertionPos(const size_t,const size_t graph_running_step,bool * const need_swap,bool * const need_wait_graph,int * swap_out_id)106 int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
107                                        bool *const need_wait_graph, int *swap_out_id) {
108   if (current_pos_ < valid_capacity_) {
109     // Start from index 1.
110     return ++current_pos_;
111   }
112   if (valid_capacity_ == 0) {
113     return kInvalidIndexValue;
114   }
115 
116   *need_swap = true;
117   int id = ids_to_indices_->Back().first;
118   int index = ids_to_indices_->Back().second;
119   if (hash_map_elements_[index].IsExpired(graph_running_step)) {
120     std::vector<Element> evicted_elements;
121     ids_to_indices_->TryEvict(1, &evicted_elements);
122     if (evicted_elements.size() != 1) {
123       MS_LOG(EXCEPTION) << "Failed to evict tail element in cache, evict element number: " << evicted_elements.size()
124                         << ", cache size: " << ids_to_indices_->size()
125                         << ", cache capacity: " << ids_to_indices_->capacity();
126     }
127 
128     *swap_out_id = evicted_elements.front().first;
129     if (*swap_out_id != id) {
130       MS_LOG(EXCEPTION) << "The evicted id should be: " << id << ", but got: " << *swap_out_id;
131     }
132     return index;
133   }
134   return kInvalidIndexValue;
135 }
136 
DumpHashMap()137 void EmbeddingHashMap::DumpHashMap() {
138   MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_;
139   MS_LOG(INFO) << "Dump hash_id_to_index: ";
140   MS_LOG(INFO) << "Dump hash_map_unit: ";
141   for (size_t i = 0; i < hash_map_elements_.size(); i++) {
142     if (!hash_map_elements_[i].IsEmpty()) {
143       MS_LOG(INFO) << "  index: " << i << " step: " << hash_map_elements_[i].step_;
144     }
145   }
146   MS_LOG(INFO) << "Dump hash map info end.";
147 }
148 }  // namespace distributed
149 }  // namespace mindspore
150