1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/ps_cache/embedding_hash_map.h"
18
19 namespace mindspore {
20 namespace ps {
ParseData(const int id,int * const swap_out_index,int * const swap_out_ids,const size_t data_step,const size_t graph_running_step,size_t * const swap_out_size,bool * const need_wait_graph)21 int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
22 const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size,
23 bool *const need_wait_graph) {
24 MS_EXCEPTION_IF_NULL(swap_out_index);
25 MS_EXCEPTION_IF_NULL(swap_out_ids);
26 MS_EXCEPTION_IF_NULL(swap_out_size);
27 bool need_swap = false;
28 auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph);
29 if (hash_index == INVALID_INDEX_VALUE) {
30 return hash_index;
31 }
32
33 if (!need_swap) {
34 hash_count_++;
35 (void)hash_id_to_index_.emplace(id, hash_index);
36 hash_map_elements_[hash_index].set_id(id);
37 hash_map_elements_[hash_index].set_step(data_step);
38 return hash_index;
39 }
40
41 swap_out_index[*swap_out_size] = hash_index;
42 swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
43 (*swap_out_size)++;
44 (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
45 (void)hash_id_to_index_.emplace(id, hash_index);
46 hash_map_elements_[hash_index].set_id(id);
47 hash_map_elements_[hash_index].set_step(data_step);
48 return hash_index;
49 }
50
FindInsertionPos(const size_t,const size_t graph_running_step,bool * const need_swap,bool * const need_wait_graph)51 int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
52 bool *const need_wait_graph) {
53 MS_EXCEPTION_IF_NULL(need_swap);
54 MS_EXCEPTION_IF_NULL(need_wait_graph);
55 int hash_index = INVALID_INDEX_VALUE;
56 while (!expired_element_full_) {
57 if (hash_map_elements_[current_pos_].IsEmpty()) {
58 hash_index = current_pos_;
59 hash_count_++;
60 } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) {
61 hash_index = current_pos_;
62 *need_swap = true;
63 } else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) {
64 graph_running_index_[graph_running_index_num_++] = current_pos_;
65 }
66 current_pos_ = (current_pos_ + 1) % hash_capacity_;
67 if (hash_index != INVALID_INDEX_VALUE) {
68 return hash_index;
69 }
70 if (current_pos_ == current_batch_start_pos_) {
71 expired_element_full_ = true;
72 MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_
73 << ") will be used, index swap will wait until the graph completed.";
74 }
75 }
76
77 if (graph_running_index_pos_ != graph_running_index_num_) {
78 *need_swap = true;
79 *need_wait_graph = true;
80 return graph_running_index_[graph_running_index_pos_++];
81 }
82 return INVALID_INDEX_VALUE;
83 }
84
DumpHashMap()85 void EmbeddingHashMap::DumpHashMap() {
86 MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
87 MS_LOG(INFO) << "Dump hash_id_to_index: ";
88 for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) {
89 MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second;
90 }
91 MS_LOG(INFO) << "Dump hash_map_unit: ";
92 for (size_t i = 0; i < hash_map_elements_.size(); i++) {
93 if (!hash_map_elements_[i].IsEmpty()) {
94 MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_
95 << " step: " << hash_map_elements_[i].step_;
96 }
97 }
98 MS_LOG(INFO) << "Dump hash map info end.";
99 }
100
Reset()101 void EmbeddingHashMap::Reset() {
102 current_batch_start_pos_ = current_pos_;
103 graph_running_index_num_ = 0;
104 graph_running_index_pos_ = 0;
105 expired_element_full_ = false;
106 }
107 } // namespace ps
108 } // namespace mindspore
109