1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"

#include <memory>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/random.h"
22
23 namespace tensorflow {
24
25 namespace {
26
get_uid()27 int64 get_uid() {
28 uint64 unsigned_rand = random::New64() & INT64_MAX;
29 return static_cast<int64>(unsigned_rand);
30 }
31
32 } // namespace
33
// Name under which an XRTCompilationCache is registered with the resource
// manager (declared in the corresponding header).
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
35
XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
                                                CompiledSubgraph* entry)
    : parent_(parent), entry_(entry) {
  // Take a reference on the entry on behalf of this wrapper; it is handed
  // back to the cache in ~EntryRefImpl via DiscardEntryRef.
  entry_->Ref();
}
41
XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
  // Return the reference taken in the constructor to the cache, which may
  // delete the entry if this was the last outstanding reference.
  parent_->DiscardEntryRef(entry_);
}
45
XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
  // Returns a non-owning view of the compiled executable; it remains valid
  // only while this reference object is alive (which keeps entry_ alive).
  return XRTCompilationCacheEntry(entry_->program.get());
}
49
// Constructs a cache that keeps at most max_number_of_entries compiled
// programs before marking least-recently-used entries for eviction.
XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
    : max_cache_entries_(max_number_of_entries) {
  // A negative capacity is a programming error; zero is allowed (every
  // entry beyond the most recent one becomes eligible for eviction).
  CHECK_GE(max_cache_entries_, 0);
  VLOG(1) << "Created compilation cache max " << max_cache_entries_
          << " entries.";
}
56
XRTCompilationCache::~XRTCompilationCache() {
  VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (!entry.second->RefCountIsOne()) {
      entry.second->Unref();
    }
  }
  // Drop the cache's own reference on every entry still in the LRU table.
  // Since all client references were stripped above, each entry reaches a
  // ref count of one and is deleted inside DiscardEntryRefLocked.
  while (!entries_by_last_use_.empty()) {
    MarkOldestEntryForEviction();
  }
  // The cache must now be completely empty; these checks catch
  // reference-counting bugs that would otherwise leak entries silently.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(cache_entries_, 0);
  CHECK_EQ(marked_for_eviction_entries_, 0);
}
75
// Releases one client-held reference on the entry identified by uid.
// Returns NotFound if no entry with that uid exists in the cache.
Status XRTCompilationCache::Release(int64 uid) {
  absl::MutexLock lock(&mu_);
  auto iter = entries_by_uid_.find(uid);

  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No cache entry found for uid ", uid);
  }

  // May delete the entry (and remove it from the maps) if this was the
  // last outstanding reference.
  DiscardEntryRefLocked(iter->second);

  VLOG(1) << "After releasing entry " << uid << " refs cache is "
          << cache_.size() << " entries ("
          << cache_entries_ + marked_for_eviction_entries_
          << "), marked for eviction "
          << (cache_.size() - entries_by_last_use_.size()) << " entries ("
          << marked_for_eviction_entries_ << ").";

  return Status::OK();
}
95
// Lock-acquiring wrapper around DiscardEntryRefLocked; used by EntryRefImpl
// destructors, which run without mu_ held.
void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
  absl::MutexLock lock(&mu_);
  DiscardEntryRefLocked(entry);
}
100
// Drops one reference on entry; requires mu_ to be held. If this is the
// last reference, the entry is first removed from cache_ and
// entries_by_uid_ so it can never be looked up again, and then deleted by
// the final Unref below.
void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    --marked_for_eviction_entries_;

    // Remove the entry from the cache.
    auto erased = cache_.erase(entry->key);
    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
  }
  entry->Unref();
}
124
// Moves the least-recently-used entry out of the LRU table into the
// "marked for eviction" state and drops the cache's reference to it.
// Callers must hold mu_ (except the destructor, where no other thread can
// race) and must ensure entries_by_last_use_ is non-empty, since begin()
// on an empty map would be dereferenced below.
void XRTCompilationCache::MarkOldestEntryForEviction() {
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
  entries_by_last_use_.erase(entry_to_mark->last_use);
  --cache_entries_;
  ++marked_for_eviction_entries_;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  DiscardEntryRefLocked(entry_to_mark);
}
138
// Transitions entry from the "marked for eviction" state back into the
// active cache, then evicts older entries if the cache is now over
// capacity. Requires mu_ to be held.
void XRTCompilationCache::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size, and
  // mark something old for eviction if necessary.
  entry->Ref();
  --marked_for_eviction_entries_;
  ++cache_entries_;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
  // which means there's only one entry not already marked for eviction), so
  // that an entry persists in the cache even if it is larger than the allocated
  // cache size.
  while (entries_by_last_use_.size() > 1 &&
         cache_entries_ > max_cache_entries_) {
    MarkOldestEntryForEviction();
  }
}
158
// Creates a new cache entry for key, runs initialize_program to compile it,
// and returns the entry holding one caller-owned reference. Requires mu_ to
// be held on entry; the lock is temporarily released around the
// (potentially lengthy) compilation, so other cache operations can run
// concurrently while this entry is visible in cache_ with
// initialized == false.
XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
    const string& key,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        initialize_program) {
  CompiledSubgraph* entry = new CompiledSubgraph();
  entry->parent = this;
  entry->key = key;
  entry->uid = get_uid();
  // Add the entry to the cache. Once the computation has been compiled,
  // UpdateEntryAfterCompilation will be called to potentially mark old entries
  // that don't fit any more for eviction.
  //
  // At this point there is one reference to entry, which is owned by the caller
  // who created the entry. A second reference, owned by the cache, will be
  // added below since we leave the entry in the 'marked for eviction' state
  // here.
  auto cache_inserted =
      cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);

  // Initialize the program outside the lock so that other cache operations
  // can proceed during the (potentially lengthy) initialization.
  Status s;
  std::unique_ptr<xla::LocalExecutable> program;
  {
    mu_.Unlock();
    { s = initialize_program(&program); }
    mu_.Lock();
  }

  // Add the entry to the uid index.
  auto uid_inserted = entries_by_uid_.insert(
      std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
  CHECK(uid_inserted.second);

  // Publish the result. 'initialized' is the flag that CompileIfKeyAbsent
  // waits on (under mu_) before using an entry found in cache_.
  entry->initialized = true;
  entry->initialization_status = s;
  if (s.ok()) {
    entry->program = std::move(program);
  }
  // Add the entry to marked_for_eviction_entries_ since it will be adjusted
  // down again when the newly-created entry gets unmarked.
  ++marked_for_eviction_entries_;
  return entry;
}
204
// Returns (compiling first if necessary) the cache entry for key. On return
// *uid identifies the entry, the caller owns one reference on it (to be
// returned via Release or an EntryRef), and the entry has been moved to the
// most-recently-used position in the LRU table. The returned status is the
// entry's compilation status, so a previously failed compilation is
// reported to every subsequent caller that hits the cached entry.
Status XRTCompilationCache::CompileIfKeyAbsent(
    const string& key, int64* uid,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        compile_function) {
  CompiledSubgraph* entry = nullptr;

  absl::MutexLock lock(&mu_);
  auto iter = cache_.find(key);

  if (iter == cache_.end()) {
    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    // Note: InitializeEntry drops and re-acquires mu_ around compilation.
    entry = InitializeEntry(key, compile_function);
  } else {
    VLOG(1) << "Before refreshing entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized. The unary
    // '+' converts the capture-free lambda to the function pointer expected
    // by this absl::Condition constructor; mu_ is released while waiting.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry);
  }

  VLOG(1) << "After refreshing entry for key " << key << " cache is "
          << cache_.size() << " entries ("
          << cache_entries_ + marked_for_eviction_entries_ << "), "
          << " marked for eviction "
          << (cache_.size() - entries_by_last_use_.size()) << " entries ("
          << marked_for_eviction_entries_ << ").";

  return entry->initialization_status;
}
260
Lookup(int64 uid,std::unique_ptr<XRTCompilationCacheEntryRef> * entry)261 Status XRTCompilationCache::Lookup(
262 int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
263 entry->reset();
264
265 absl::MutexLock lock(&mu_);
266 const auto iter = entries_by_uid_.find(uid);
267 if (iter == entries_by_uid_.end()) {
268 return errors::NotFound("No executable found for uid ", uid);
269 }
270 CompiledSubgraph* cache_entry = iter->second;
271 *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
272 new EntryRefImpl(this, cache_entry));
273 return Status::OK();
274 }
275
DebugString() const276 string XRTCompilationCache::DebugString() const {
277 return "XRTCompilationCache";
278 }
279
280 } // namespace tensorflow
281