• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/ps_cache/embedding_hash_map.h"
18 
19 namespace mindspore {
20 namespace ps {
ParseData(const int id,int * const swap_out_index,int * const swap_out_ids,const size_t data_step,const size_t graph_running_step,size_t * const swap_out_size,bool * const need_wait_graph)21 int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
22                                 const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size,
23                                 bool *const need_wait_graph) {
24   MS_EXCEPTION_IF_NULL(swap_out_index);
25   MS_EXCEPTION_IF_NULL(swap_out_ids);
26   MS_EXCEPTION_IF_NULL(swap_out_size);
27   bool need_swap = false;
28   auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph);
29   if (hash_index == INVALID_INDEX_VALUE) {
30     return hash_index;
31   }
32 
33   if (!need_swap) {
34     hash_count_++;
35     (void)hash_id_to_index_.emplace(id, hash_index);
36     hash_map_elements_[hash_index].set_id(id);
37     hash_map_elements_[hash_index].set_step(data_step);
38     return hash_index;
39   }
40 
41   swap_out_index[*swap_out_size] = hash_index;
42   swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
43   (*swap_out_size)++;
44   (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
45   (void)hash_id_to_index_.emplace(id, hash_index);
46   hash_map_elements_[hash_index].set_id(id);
47   hash_map_elements_[hash_index].set_step(data_step);
48   return hash_index;
49 }
50 
FindInsertionPos(const size_t,const size_t graph_running_step,bool * const need_swap,bool * const need_wait_graph)51 int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
52                                        bool *const need_wait_graph) {
53   MS_EXCEPTION_IF_NULL(need_swap);
54   MS_EXCEPTION_IF_NULL(need_wait_graph);
55   int hash_index = INVALID_INDEX_VALUE;
56   while (!expired_element_full_) {
57     if (hash_map_elements_[current_pos_].IsEmpty()) {
58       hash_index = current_pos_;
59       hash_count_++;
60     } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) {
61       hash_index = current_pos_;
62       *need_swap = true;
63     } else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) {
64       graph_running_index_[graph_running_index_num_++] = current_pos_;
65     }
66     current_pos_ = (current_pos_ + 1) % hash_capacity_;
67     if (hash_index != INVALID_INDEX_VALUE) {
68       return hash_index;
69     }
70     if (current_pos_ == current_batch_start_pos_) {
71       expired_element_full_ = true;
72       MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_
73                    << ") will be used, index swap will wait until the graph completed.";
74     }
75   }
76 
77   if (graph_running_index_pos_ != graph_running_index_num_) {
78     *need_swap = true;
79     *need_wait_graph = true;
80     return graph_running_index_[graph_running_index_pos_++];
81   }
82   return INVALID_INDEX_VALUE;
83 }
84 
DumpHashMap()85 void EmbeddingHashMap::DumpHashMap() {
86   MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
87   MS_LOG(INFO) << "Dump hash_id_to_index: ";
88   for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) {
89     MS_LOG(INFO) << "  id: " << iter->first << " index: " << iter->second;
90   }
91   MS_LOG(INFO) << "Dump hash_map_unit: ";
92   for (size_t i = 0; i < hash_map_elements_.size(); i++) {
93     if (!hash_map_elements_[i].IsEmpty()) {
94       MS_LOG(INFO) << "  index: " << i << " id: " << hash_map_elements_[i].id_
95                    << " step: " << hash_map_elements_[i].step_;
96     }
97   }
98   MS_LOG(INFO) << "Dump hash map info end.";
99 }
100 
Reset()101 void EmbeddingHashMap::Reset() {
102   current_batch_start_pos_ = current_pos_;
103   graph_running_index_num_ = 0;
104   graph_running_index_pos_ = 0;
105   expired_element_full_ = false;
106 }
107 }  // namespace ps
108 }  // namespace mindspore
109