• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
17 
18 #include "absl/synchronization/mutex.h"
19 #include "tensorflow/compiler/xla/client/local_client.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/random/random.h"
22 
23 namespace tensorflow {
24 
25 namespace {
26 
get_uid()27 int64 get_uid() {
28   uint64 unsigned_rand = random::New64() & INT64_MAX;
29   return static_cast<int64>(unsigned_rand);
30 }
31 
32 }  // namespace
33 
// Well-known resource name under which the XRT compilation cache is
// registered/looked up by clients (declared extern in the header).
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
35 
// Takes a reference on `entry` for the lifetime of this EntryRefImpl so
// the underlying cache entry cannot be deleted while a client holds it.
XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
                                                CompiledSubgraph* entry)
    : parent_(parent), entry_(entry) {
  entry_->Ref();
}
41 
// Returns the reference taken in the constructor back to the parent
// cache, which may delete the entry if this was the last reference.
XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
  parent_->DiscardEntryRef(entry_);
}
45 
// Returns an XRTCompilationCacheEntry wrapping the entry's raw executable
// pointer. The pointer remains valid only while this ref is alive.
XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
  return XRTCompilationCacheEntry(entry_->program.get());
}
49 
// Constructs a cache that tries to keep at most `max_number_of_entries`
// live (non-marked-for-eviction) entries. A negative limit is a
// programming error.
XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
    : max_cache_entries_(max_number_of_entries) {
  CHECK_GE(max_cache_entries_, 0);
  VLOG(1) << "Created compilation cache max " << max_cache_entries_
          << " entries.";
}
56 
// Tears down the cache, force-dropping any outstanding client references
// and evicting every remaining entry. The CHECKs at the end verify the
// bookkeeping invariants: everything must have been fully released.
XRTCompilationCache::~XRTCompilationCache() {
  VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    // Drop refs until only the cache's own reference remains.
    while (!entry.second->RefCountIsOne()) {
      entry.second->Unref();
    }
  }
  // Evict every entry still in the LRU table; each eviction releases the
  // cache's reference, which (refcount being one) deletes the entry.
  while (!entries_by_last_use_.empty()) {
    MarkOldestEntryForEviction();
  }
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(cache_entries_, 0);
  CHECK_EQ(marked_for_eviction_entries_, 0);
}
75 
Release(int64 uid)76 Status XRTCompilationCache::Release(int64 uid) {
77   absl::MutexLock lock(&mu_);
78   auto iter = entries_by_uid_.find(uid);
79 
80   if (iter == entries_by_uid_.end()) {
81     return errors::NotFound("No cache entry found for uid ", uid);
82   }
83 
84   DiscardEntryRefLocked(iter->second);
85 
86   VLOG(1) << "After releasing entry " << uid << " refs cache is "
87           << cache_.size() << " entries ("
88           << cache_entries_ + marked_for_eviction_entries_
89           << "), marked for eviction "
90           << (cache_.size() - entries_by_last_use_.size()) << " entries ("
91           << marked_for_eviction_entries_ << ").";
92 
93   return Status::OK();
94 }
95 
// Thread-safe wrapper around DiscardEntryRefLocked: acquires mu_ and
// drops one reference on `entry`.
void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
  absl::MutexLock lock(&mu_);
  DiscardEntryRefLocked(entry);
}
100 
// Drops one reference on `entry`; mu_ must be held. If this is the last
// reference, the entry is first removed from cache_ and entries_by_uid_
// so it can never be looked up (and thus resurrected) again, and the
// marked-for-eviction counter is decremented before the final Unref
// deletes it.
void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    --marked_for_eviction_entries_;

    // Remove the entry from the cache.
    auto erased = cache_.erase(entry->key);
    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
  }
  entry->Unref();
}
124 
// Moves the least-recently-used live entry into the marked-for-eviction
// state: removes it from the LRU table, shifts it between the two size
// counters, and drops the cache's reference. mu_ must be held and
// entries_by_last_use_ must be non-empty.
void XRTCompilationCache::MarkOldestEntryForEviction() {
  // begin() is the smallest last_use value, i.e. the oldest entry.
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
  entries_by_last_use_.erase(entry_to_mark->last_use);
  --cache_entries_;
  ++marked_for_eviction_entries_;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  DiscardEntryRefLocked(entry_to_mark);
}
138 
// Un-marks `entry` (previously marked for eviction, or newly created) by
// re-taking a cache-owned reference and moving it back into the live
// count, then evicts older entries if the cache is now over capacity.
// mu_ must be held.
void XRTCompilationCache::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size, and
  // mark something old for eviction if necessary.
  entry->Ref();
  --marked_for_eviction_entries_;
  ++cache_entries_;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
  // which means there's only one entry not already marked for eviction), so
  // that an entry persists in the cache even if it is larger than the allocated
  // cache size.
  while (entries_by_last_use_.size() > 1 &&
         cache_entries_ > max_cache_entries_) {
    MarkOldestEntryForEviction();
  }
}
158 
// Creates a new cache entry for `key` and runs `initialize_program` to
// compile it. mu_ must be held on entry; it is temporarily RELEASED
// around the (potentially lengthy) compilation so other cache operations
// can proceed, and re-acquired before the entry is finalized. Concurrent
// lookups of the same key will find the entry in cache_ and block on
// entry->initialized until compilation finishes. Returns the entry with
// one caller-owned reference; the entry starts in the marked-for-eviction
// state and is unmarked by the caller via LookupEntryMarkedForEviction.
XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
    const string& key,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        initialize_program) {
  CompiledSubgraph* entry = new CompiledSubgraph();
  entry->parent = this;
  entry->key = key;
  entry->uid = get_uid();
  // Add the entry to the cache. Once the computation has been compiled,
  // UpdateEntryAfterCompilation will be called to potentially mark old entries
  // that don't fit any more for eviction.
  //
  // At this point there is one reference to entry, which is owned by the caller
  // who created the entry. A second reference, owned by the cache, will be
  // added below since we leave the entry in the 'marked for eviction' state
  // here.
  auto cache_inserted =
      cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);

  // Initialize the program outside the lock so that other cache operations
  // can proceed during the (potentially lengthy) initialization.
  Status s;
  std::unique_ptr<xla::LocalExecutable> program;
  {
    mu_.Unlock();
    { s = initialize_program(&program); }
    mu_.Lock();
  }

  // Add the entry to the uid index.
  auto uid_inserted = entries_by_uid_.insert(
      std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
  CHECK(uid_inserted.second);

  // Publishing initialized = true (under mu_) releases any threads blocked
  // in CompileIfKeyAbsent's Await on this entry.
  entry->initialized = true;
  entry->initialization_status = s;
  if (s.ok()) {
    entry->program = std::move(program);
  }
  // Add the entry to marked_for_eviction_entries_ since it will be adjusted
  // down again when the newly-created entry gets unmarked.
  ++marked_for_eviction_entries_;
  return entry;
}
204 
// Looks up `key`, compiling a new entry via `compile_function` if absent.
// On return *uid identifies the entry, the caller owns one reference to
// it, and the entry is the most recently used in the LRU table. If the
// entry was found but still compiling (on another thread), blocks until
// initialization completes. Returns the entry's compilation status.
Status XRTCompilationCache::CompileIfKeyAbsent(
    const string& key, int64* uid,
    const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
        compile_function) {
  CompiledSubgraph* entry = nullptr;

  absl::MutexLock lock(&mu_);
  auto iter = cache_.find(key);

  if (iter == cache_.end()) {
    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    // NOTE: InitializeEntry temporarily drops mu_ while compiling.
    entry = InitializeEntry(key, compile_function);
  } else {
    VLOG(1) << "Before refreshing entry for key " << key << " cache is "
            << cache_.size() << " entries ("
            << cache_entries_ + marked_for_eviction_entries_ << "), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_entries_ << ").";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    // Re-take the cache-owned reference and rebalance the size counters.
    LookupEntryMarkedForEviction(entry);
  }

  VLOG(1) << "After refreshing entry for key " << key << " cache is "
          << cache_.size() << " entries ("
          << cache_entries_ + marked_for_eviction_entries_ << "), "
          << " marked for eviction "
          << (cache_.size() - entries_by_last_use_.size()) << " entries ("
          << marked_for_eviction_entries_ << ").";

  return entry->initialization_status;
}
260 
Lookup(int64 uid,std::unique_ptr<XRTCompilationCacheEntryRef> * entry)261 Status XRTCompilationCache::Lookup(
262     int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
263   entry->reset();
264 
265   absl::MutexLock lock(&mu_);
266   const auto iter = entries_by_uid_.find(uid);
267   if (iter == entries_by_uid_.end()) {
268     return errors::NotFound("No executable found for uid ", uid);
269   }
270   CompiledSubgraph* cache_entry = iter->second;
271   *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
272       new EntryRefImpl(this, cache_entry));
273   return Status::OK();
274 }
275 
DebugString() const276 string XRTCompilationCache::DebugString() const {
277   return "XRTCompilationCache";
278 }
279 
280 }  // namespace tensorflow
281