/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_

#include <functional>
#include <tuple>

#include "absl/container/node_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

25 // A simple LRU cache. Not thread-safe.
26 // Value must be copyable and moveable. The intent is that Value is typically
27 // a smart-pointer type.
28 template <typename Key, typename Value,
29           typename Hash = typename absl::node_hash_map<Key, Value>::hasher,
30           typename Eq = typename absl::node_hash_map<Key, Value>::key_equal>
31 class LRUCache {
32  private:
33   struct LRUListEntry {
34     LRUListEntry* next;
35     LRUListEntry* prev;
36   };
37 
38  public:
39   // Multiple LRUCaches can share a LRU list, meaning that the capacity and
40   // eviction policy is shared. The user provides an LRU list
41   // to the cache constructor, and must ensure that it remains alive as long
42   // as the cache does.
43   class LRUList {
44    public:
LRUList(int capacity)45     explicit LRUList(int capacity) : capacity_(capacity) {
46       head_.next = &head_;
47       head_.prev = &head_;
48     }
~LRUList()49     ~LRUList() {
50       CHECK(head_.next == &head_);
51       CHECK(head_.prev == &head_);
52     }
53 
54     LRUList(const LRUList&) = delete;
55     LRUList(LRUList&&) = delete;
56     LRUList& operator=(const LRUList&) = delete;
57     LRUList& operator=(LRUList&&) = delete;
58 
Capacity()59     int Capacity() const { return capacity_; }
Size()60     int Size() const { return size_; }
61 
62     void Clear();
63 
64    private:
65     friend class LRUCache;
66     int capacity_;
67     int size_ = 0;
68 
69     // Root of a circular doubly-linked list of entries, in order from least
70     // recently used to most recently used. An "empty" cache always contains
71     // this element in the LRU list.
72     LRUListEntry head_;
73   };
74 
LRUCache(LRUList * lru_list)75   explicit LRUCache(LRUList* lru_list) : lru_list_(lru_list) {}
76   ~LRUCache();
77 
78   LRUCache(const LRUCache&) = delete;
79   LRUCache(LRUCache&&) = delete;
80   LRUCache& operator=(const LRUCache&) = delete;
81   LRUCache& operator=(LRUCache&&) = delete;
82 
83   // Returns the `value` associated with `key`. Creates a value with `factory`
84   // and inserts it if absent.
85   Value GetOrCreateIfAbsent(const Key& key,
86                             const std::function<Value(const Key&)>& factory);
87 
88   // Removes all entries from the cache.
89   void Clear();
90 
Size()91   int Size() const { return entries_.size(); }
Capacity()92   int Capacity() const { return lru_list_->Capacity(); }
93 
94  private:
95   LRUList* lru_list_;
96 
97   struct Entry : public LRUListEntry {
98     Entry() = default;
99 
100     // Pointer to the key in `entries_`. absl::node_hash_map<> promises
101     // pointer stability for keys.
102     const Key* key;
103     LRUCache* container;
104     absl::optional<Value> value;
105   };
106 
107   // We use `node_hash_map` because we want to guarantee pointer stability for
108   // keys and values.
109   absl::node_hash_map<Key, Entry, Hash, Eq> entries_;
110 };
112 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()113 void LRUCache<Key, Value, Hash, Eq>::LRUList::Clear() {
114   while (head_.next != &head_) {
115     static_cast<Entry*>(head_.next)->container->Clear();
116   }
117   size_ = 0;
118 }
119 
120 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()121 void LRUCache<Key, Value, Hash, Eq>::Clear() {
122   for (auto& e : entries_) {
123     LRUListEntry* l = &e.second;
124     l->next->prev = l->prev;
125     l->prev->next = l->next;
126     --lru_list_->size_;
127   }
128   entries_.clear();
129 }
130 
131 template <typename Key, typename Value, typename Hash, typename Eq>
~LRUCache()132 LRUCache<Key, Value, Hash, Eq>::~LRUCache() {
133   Clear();
134 }
135 
136 template <typename Key, typename Value, typename Hash, typename Eq>
GetOrCreateIfAbsent(const Key & key,const std::function<Value (const Key &)> & factory)137 Value LRUCache<Key, Value, Hash, Eq>::GetOrCreateIfAbsent(
138     const Key& key, const std::function<Value(const Key&)>& factory) {
139   typename absl::node_hash_map<Key, Entry, Hash, Eq>::iterator it;
140   bool inserted;
141   std::tie(it, inserted) = entries_.try_emplace(key);
142   Entry& entry = it->second;
143   if (inserted) {
144     entry.key = &it->first;
145     entry.value = factory(*entry.key);
146     ++lru_list_->size_;
147   } else {
148     // Removes the entry from the LRU list, in preparation for adding it
149     // to the back of the list.
150     entry.prev->next = entry.next;
151     entry.next->prev = entry.prev;
152   }
153   // (Re-)adds entry to the back of the LRU list. Since it is now the
154   // most recently used element, it goes at the back.
155   LRUListEntry& lru_head = lru_list_->head_;
156   entry.container = this;
157   entry.prev = lru_head.prev;
158   entry.next = &lru_head;
159   lru_head.prev->next = &entry;
160   lru_head.prev = &entry;
161 
162   Value v = *entry.value;
163 
164   // Evict an LRU entry if we are over capacity.
165   if (lru_list_->size_ > lru_list_->capacity_) {
166     Entry* to_remove = static_cast<Entry*>(lru_head.next);
167     to_remove->next->prev = &lru_head;
168     lru_head.next = to_remove->next;
169     to_remove->container->entries_.erase(*to_remove->key);
170     --lru_list_->size_;
171   }
172   return v;
173 }

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_