1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
17
18 #include <stdlib.h>
19
20 #include <string>
21
22 #include "absl/synchronization/mutex.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/random/random.h"
26
27 namespace tensorflow {
28
29 namespace {
30
get_uid()31 int64 get_uid() {
32 uint64 unsigned_rand = random::New64() & INT64_MAX;
33 return static_cast<int64>(unsigned_rand);
34 }
35
GetCompilationCacheSizeFromEnv()36 int64 GetCompilationCacheSizeFromEnv() {
37 const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE");
38 return env == nullptr ? 1024 : std::stol(env);
39 }
40
41 } // namespace
42
// Name under which the cache object is registered with the ResourceMgr; see
// GetOrCreateCompilationCache() below.
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
44
EntryRefImpl(XRTCompilationCache * parent,CompiledSubgraph * entry)45 XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
46 CompiledSubgraph* entry)
47 : parent_(parent), entry_(entry) {
48 entry_->Ref();
49 }
50
~EntryRefImpl()51 XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
52 parent_->DiscardEntryRef(entry_);
53 }
54
get()55 XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
56 return XRTCompilationCacheEntry(entry_->program.get());
57 }
58
XRTCompilationCache(int max_number_of_entries)59 XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
60 : max_cache_entries_(max_number_of_entries) {
61 CHECK_GE(max_cache_entries_, 0);
62 VLOG(1) << "Created compilation cache max " << max_cache_entries_
63 << " entries.";
64 }
65
~XRTCompilationCache()66 XRTCompilationCache::~XRTCompilationCache() {
67 VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
68 // A buggy client may be holding onto a reference, or a client might have
69 // crashed while holding onto a reference. In either case, discard all
70 // outstanding client references to avoid leaking storage.
71 for (const auto& entry : entries_by_uid_) {
72 while (!entry.second->RefCountIsOne()) {
73 entry.second->Unref();
74 }
75 }
76 while (!entries_by_last_use_.empty()) {
77 MarkOldestEntryForEviction();
78 }
79 CHECK_EQ(cache_.size(), 0);
80 CHECK_EQ(entries_by_uid_.size(), 0);
81 CHECK_EQ(cache_entries_, 0);
82 CHECK_EQ(marked_for_eviction_entries_, 0);
83 }
84
Release(int64_t uid)85 Status XRTCompilationCache::Release(int64_t uid) {
86 absl::MutexLock lock(&mu_);
87 auto iter = entries_by_uid_.find(uid);
88
89 if (iter == entries_by_uid_.end()) {
90 return errors::NotFound("No cache entry found for uid ", uid);
91 }
92
93 DiscardEntryRefLocked(iter->second);
94
95 VLOG(1) << "After releasing entry " << uid << " refs cache is "
96 << cache_.size() << " entries ("
97 << cache_entries_ + marked_for_eviction_entries_
98 << "), marked for eviction "
99 << (cache_.size() - entries_by_last_use_.size()) << " entries ("
100 << marked_for_eviction_entries_ << ").";
101
102 return Status::OK();
103 }
104
DiscardEntryRef(CompiledSubgraph * entry)105 void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
106 absl::MutexLock lock(&mu_);
107 DiscardEntryRefLocked(entry);
108 }
109
DiscardEntryRefLocked(CompiledSubgraph * entry)110 void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
111 if (entry->RefCountIsOne()) {
112 // The last reference to this entry is going away, so really delete it from
113 // the cache in such a way that it can't be restored by being looked up
114 // again.
115
116 // Sanity-check that it has been marked for eviction.
117 CHECK(entries_by_last_use_.find(entry->last_use) ==
118 entries_by_last_use_.end());
119 // Update the counter tracking how much space is taken up by entries that
120 // are marked for eviction.
121 --marked_for_eviction_entries_;
122
123 // Remove the entry from the cache.
124 auto erased = cache_.erase(entry->key);
125 if (erased == 0) {
126 LOG(FATAL) << "Tried to discard nonexistent cache entry";
127 }
128 erased = entries_by_uid_.erase(entry->uid);
129 CHECK_EQ(erased, 1);
130 }
131 entry->Unref();
132 }
133
MarkOldestEntryForEviction()134 void XRTCompilationCache::MarkOldestEntryForEviction() {
135 CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
136 VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
137 entries_by_last_use_.erase(entry_to_mark->last_use);
138 --cache_entries_;
139 ++marked_for_eviction_entries_;
140 // Discard the cache's reference to entry. If steps are holding onto
141 // references to entry it won't be deleted until the last step holding it
142 // completes. It stays in the cache in the meantime and can be resurrected
143 // by a call to CompileIfKeyAbsent if that occurs before the last reference
144 // expires.
145 DiscardEntryRefLocked(entry_to_mark);
146 }
147
LookupEntryMarkedForEviction(CompiledSubgraph * entry)148 void XRTCompilationCache::LookupEntryMarkedForEviction(
149 CompiledSubgraph* entry) {
150 // The entry was previously marked for eviction (or is newly created) so
151 // unmark it. Add a reference (owned by the cache), update the cache size, and
152 // mark something old for eviction if necessary.
153 entry->Ref();
154 --marked_for_eviction_entries_;
155 ++cache_entries_;
156
157 // Mark the least-recently-used non-marked entry for eviction. Never mark the
158 // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
159 // which means there's only one entry not already marked for eviction), so
160 // that an entry persists in the cache even if it is larger than the allocated
161 // cache size.
162 while (entries_by_last_use_.size() > 1 &&
163 cache_entries_ > max_cache_entries_) {
164 MarkOldestEntryForEviction();
165 }
166 }
167
InitializeEntry(const string & key,const std::function<Status (std::unique_ptr<xla::LocalExecutable> *)> & initialize_program)168 XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
169 const string& key,
170 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
171 initialize_program) {
172 CompiledSubgraph* entry = new CompiledSubgraph();
173 entry->parent = this;
174 entry->key = key;
175 entry->uid = get_uid();
176 // Add the entry to the cache. Once the computation has been compiled,
177 // UpdateEntryAfterCompilation will be called to potentially mark old entries
178 // that don't fit any more for eviction.
179 //
180 // At this point there is one reference to entry, which is owned by the caller
181 // who created the entry. A second reference, owned by the cache, will be
182 // added below since we leave the entry in the 'marked for eviction' state
183 // here.
184 auto cache_inserted =
185 cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
186 CHECK(cache_inserted.second);
187
188 // Initialize the program outside the lock so that other cache operations
189 // can proceed during the (potentially lengthy) initialization.
190 Status s;
191 std::unique_ptr<xla::LocalExecutable> program;
192 {
193 mu_.Unlock();
194 { s = initialize_program(&program); }
195 mu_.Lock();
196 }
197
198 // Add the entry to the uid index.
199 auto uid_inserted = entries_by_uid_.insert(
200 std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
201 CHECK(uid_inserted.second);
202
203 entry->initialized = true;
204 entry->initialization_status = s;
205 if (s.ok()) {
206 entry->program = std::move(program);
207 }
208 // Add the entry to marked_for_eviction_entries_ since it will be adjusted
209 // down again when the newly-created entry gets unmarked.
210 ++marked_for_eviction_entries_;
211 return entry;
212 }
213
CompileIfKeyAbsent(const string & key,int64 * uid,const std::function<Status (std::unique_ptr<xla::LocalExecutable> *)> & compile_function)214 Status XRTCompilationCache::CompileIfKeyAbsent(
215 const string& key, int64* uid,
216 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
217 compile_function) {
218 CompiledSubgraph* entry = nullptr;
219
220 absl::MutexLock lock(&mu_);
221 auto iter = cache_.find(key);
222
223 if (iter == cache_.end()) {
224 // The single ref on the newly-created entry is owned by the caller.
225 VLOG(1) << "Before adding new entry for key " << key << " cache is "
226 << cache_.size() << " entries ("
227 << cache_entries_ + marked_for_eviction_entries_ << "), "
228 << " marked for eviction "
229 << (cache_.size() - entries_by_last_use_.size()) << " entries ("
230 << marked_for_eviction_entries_ << ").";
231 entry = InitializeEntry(key, compile_function);
232 } else {
233 VLOG(1) << "Before refreshing entry for key " << key << " cache is "
234 << cache_.size() << " entries ("
235 << cache_entries_ + marked_for_eviction_entries_ << "), "
236 << " marked for eviction "
237 << (cache_.size() - entries_by_last_use_.size()) << " entries ("
238 << marked_for_eviction_entries_ << ").";
239 entry = iter->second;
240 // Make a new reference that is owned by the caller.
241 entry->Ref();
242 // Block if necessary until the subgraph has been initialized.
243 mu_.Await(absl::Condition(
244 +[](CompiledSubgraph* e) { return e->initialized; }, entry));
245 }
246
247 // Let the caller know the uid of the entry.
248 *uid = entry->uid;
249
250 // Remove the old LRU-table entry if it wasn't already marked for eviction.
251 auto erased = entries_by_last_use_.erase(entry->last_use);
252 // Update the LRU table indicating this entry is the most recently used.
253 entry->last_use = use_counter_++;
254 entries_by_last_use_[entry->last_use] = entry;
255 if (erased == 0) {
256 // The entry had been marked for eviction, or is newly created.
257 LookupEntryMarkedForEviction(entry);
258 }
259
260 VLOG(1) << "After refreshing entry for key " << key << " cache is "
261 << cache_.size() << " entries ("
262 << cache_entries_ + marked_for_eviction_entries_ << "), "
263 << " marked for eviction "
264 << (cache_.size() - entries_by_last_use_.size()) << " entries ("
265 << marked_for_eviction_entries_ << ").";
266
267 return entry->initialization_status;
268 }
269
Lookup(int64_t uid,std::unique_ptr<XRTCompilationCacheEntryRef> * entry)270 Status XRTCompilationCache::Lookup(
271 int64_t uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
272 entry->reset();
273
274 absl::MutexLock lock(&mu_);
275 const auto iter = entries_by_uid_.find(uid);
276 if (iter == entries_by_uid_.end()) {
277 return errors::NotFound("No executable found for uid ", uid);
278 }
279 CompiledSubgraph* cache_entry = iter->second;
280 *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
281 new EntryRefImpl(this, cache_entry));
282 return Status::OK();
283 }
284
DebugString() const285 string XRTCompilationCache::DebugString() const {
286 return "XRTCompilationCache";
287 }
288
GetOrCreateCompilationCache(ResourceMgr * rm,int64_t max_number_of_entries)289 xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
290 ResourceMgr* rm, int64_t max_number_of_entries) {
291 if (max_number_of_entries == 0) {
292 max_number_of_entries = GetCompilationCacheSizeFromEnv();
293 }
294 XRTCompilationCache* cache;
295 TF_RETURN_IF_ERROR(rm->LookupOrCreate<XRTCompilationCache>(
296 rm->default_container(), kXRTCompilationCacheResourceName, &cache,
297 [&](XRTCompilationCache** new_cache) {
298 *new_cache = new XRTCompilationCache(max_number_of_entries);
299 return Status::OK();
300 }));
301 return RefPtr<XRTCompilationCache>(cache);
302 }
303
304 } // namespace tensorflow
305