• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
18 #define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
19 
20 #include <cmath>
21 #include <utility>
22 #include <memory>
23 #include <vector>
24 #include <list>
25 #include <mutex>
26 #include "utils/hash_map.h"
27 #include "utils/convert_utils_base.h"
28 #include "include/backend/visible.h"
29 #include "distributed/embedding_cache/cache_strategy/cache.h"
30 
31 namespace mindspore {
32 namespace distributed {
// Step value that marks "no step recorded": a HashMapElement whose step_ equals this value reports
// IsEmpty().
static constexpr size_t kInvalidStepValue = 0;
// Value used to denote an invalid index.
static constexpr int kInvalidIndexValue = -1;

// The minimum valid capacity.
static constexpr size_t kMinimumCapacity = 2;
40 
41 struct HashMapElement {
42   // The current global step of cache prefetching operation.
43   size_t step_{kInvalidStepValue};
44 
IsEmptyHashMapElement45   bool IsEmpty() const { return step_ == kInvalidStepValue; }
IsExpiredHashMapElement46   bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
StepEqualHashMapElement47   bool StepEqual(size_t step) const { return step_ == step; }
set_stepHashMapElement48   void set_step(size_t step) { step_ = step; }
49 };
50 
51 // EmbeddingHashMap is used to manage the id -> index mapping of the embedding cache table on the host
52 // side. The cache content can be stored on the device or host side.
53 class BACKEND_EXPORT EmbeddingHashMap {
54  public:
55   using Element = typename Cache<int, int>::Element;
56 
57   explicit EmbeddingHashMap(size_t hash_capacity);
58 
59   ~EmbeddingHashMap() = default;
60 
61   // Find the insertion position (index) in the hash map for an id.
62   // If the hash map capacity is insufficient, return the information of ids and indices that need to be swapped.
63   int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step,
64                 const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph);
65 
66   int GetOrInsertDataUnsafe(const int key);
67 
68   // Get the global step of a element in hash map.
69   size_t hash_step(const int hash_index) const;
70   // Set the global step of a element in hash map.
71   void set_hash_step(const int hash_index, const size_t step);
72 
73   // Get capacity of hash map.
74   size_t hash_capacity() const;
75 
76   // Get index by id.
77   bool GetIndex(const int id, int *index) const;
78 
79   const std::list<Element> &Export() const;
80 
81   // Reset the hash map.
Reset()82   void Reset() {}
83 
84   void DumpHashMap();
85 
86  private:
87   // Find the insertion position (index) in the hash map for an id.
88   int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap,
89                        bool *const need_wait_graph, int *swap_out_id);
90 
91   int InsertDataUnsafe(const int key);
92 
93   int FindPosUnsafe();
94 
95   // The hash map capacity.
96   size_t hash_capacity_;
97 
98   // The hash map valid capacity(less than hash_capacity_).
99   size_t valid_capacity_;
100 
101   // Record all elements in this hash map.
102   std::vector<HashMapElement> hash_map_elements_;
103 
104   // The id -> index mapping.
105   std::unique_ptr<Cache<int, int>> ids_to_indices_;
106 
107   // The cursor that records the current used index.
108   size_t current_pos_;
109 };
110 }  // namespace distributed
111 }  // namespace mindspore
112 #endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
113