/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
17 
18 #include <stdlib.h>
19 
20 #include <string>
21 
22 #include "absl/synchronization/mutex.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/random/random.h"
26 
27 namespace tensorflow {
28 
29 namespace {
30 
get_uid()31 int64 get_uid() {
32   uint64 unsigned_rand = random::New64() & INT64_MAX;
33   return static_cast<int64>(unsigned_rand);
34 }
35 
GetCompilationCacheSizeFromEnv()36 int64 GetCompilationCacheSizeFromEnv() {
37   const char* env = getenv("TF_XRT_COMPILATION_CACHE_SIZE");
38   return env == nullptr ? 1024 : std::stol(env);
39 }
40 
41 }  // namespace
42 
// Name under which the compilation cache is registered in the ResourceMgr's
// default container (removes the code-viewer line-number artifact that had
// been fused into this line).
const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
44 
EntryRefImpl(XRTCompilationCache * parent,CompiledSubgraph * entry)45 XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
46                                                 CompiledSubgraph* entry)
47     : parent_(parent), entry_(entry) {
48   entry_->Ref();
49 }
50 
~EntryRefImpl()51 XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
52   parent_->DiscardEntryRef(entry_);
53 }
54 
get()55 XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
56   return XRTCompilationCacheEntry(entry_->program.get());
57 }
58 
XRTCompilationCache(int max_number_of_entries)59 XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
60     : max_cache_entries_(max_number_of_entries) {
61   CHECK_GE(max_cache_entries_, 0);
62   VLOG(1) << "Created compilation cache max " << max_cache_entries_
63           << " entries.";
64 }
65 
~XRTCompilationCache()66 XRTCompilationCache::~XRTCompilationCache() {
67   VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
68   // A buggy client may be holding onto a reference, or a client might have
69   // crashed while holding onto a reference. In either case, discard all
70   // outstanding client references to avoid leaking storage.
71   for (const auto& entry : entries_by_uid_) {
72     while (!entry.second->RefCountIsOne()) {
73       entry.second->Unref();
74     }
75   }
76   while (!entries_by_last_use_.empty()) {
77     MarkOldestEntryForEviction();
78   }
79   CHECK_EQ(cache_.size(), 0);
80   CHECK_EQ(entries_by_uid_.size(), 0);
81   CHECK_EQ(cache_entries_, 0);
82   CHECK_EQ(marked_for_eviction_entries_, 0);
83 }
84 
Release(int64_t uid)85 Status XRTCompilationCache::Release(int64_t uid) {
86   absl::MutexLock lock(&mu_);
87   auto iter = entries_by_uid_.find(uid);
88 
89   if (iter == entries_by_uid_.end()) {
90     return errors::NotFound("No cache entry found for uid ", uid);
91   }
92 
93   DiscardEntryRefLocked(iter->second);
94 
95   VLOG(1) << "After releasing entry " << uid << " refs cache is "
96           << cache_.size() << " entries ("
97           << cache_entries_ + marked_for_eviction_entries_
98           << "), marked for eviction "
99           << (cache_.size() - entries_by_last_use_.size()) << " entries ("
100           << marked_for_eviction_entries_ << ").";
101 
102   return Status::OK();
103 }
104 
DiscardEntryRef(CompiledSubgraph * entry)105 void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
106   absl::MutexLock lock(&mu_);
107   DiscardEntryRefLocked(entry);
108 }
109 
DiscardEntryRefLocked(CompiledSubgraph * entry)110 void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
111   if (entry->RefCountIsOne()) {
112     // The last reference to this entry is going away, so really delete it from
113     // the cache in such a way that it can't be restored by being looked up
114     // again.
115 
116     // Sanity-check that it has been marked for eviction.
117     CHECK(entries_by_last_use_.find(entry->last_use) ==
118           entries_by_last_use_.end());
119     // Update the counter tracking how much space is taken up by entries that
120     // are marked for eviction.
121     --marked_for_eviction_entries_;
122 
123     // Remove the entry from the cache.
124     auto erased = cache_.erase(entry->key);
125     if (erased == 0) {
126       LOG(FATAL) << "Tried to discard nonexistent cache entry";
127     }
128     erased = entries_by_uid_.erase(entry->uid);
129     CHECK_EQ(erased, 1);
130   }
131   entry->Unref();
132 }
133 
MarkOldestEntryForEviction()134 void XRTCompilationCache::MarkOldestEntryForEviction() {
135   CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
136   VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
137   entries_by_last_use_.erase(entry_to_mark->last_use);
138   --cache_entries_;
139   ++marked_for_eviction_entries_;
140   // Discard the cache's reference to entry. If steps are holding onto
141   // references to entry it won't be deleted until the last step holding it
142   // completes. It stays in the cache in the meantime and can be resurrected
143   // by a call to CompileIfKeyAbsent if that occurs before the last reference
144   // expires.
145   DiscardEntryRefLocked(entry_to_mark);
146 }
147 
LookupEntryMarkedForEviction(CompiledSubgraph * entry)148 void XRTCompilationCache::LookupEntryMarkedForEviction(
149     CompiledSubgraph* entry) {
150   // The entry was previously marked for eviction (or is newly created) so
151   // unmark it. Add a reference (owned by the cache), update the cache size, and
152   // mark something old for eviction if necessary.
153   entry->Ref();
154   --marked_for_eviction_entries_;
155   ++cache_entries_;
156 
157   // Mark the least-recently-used non-marked entry for eviction. Never mark the
158   // most-recently used entry (i.e., do nothing if entries_by_last_use_ == 1
159   // which means there's only one entry not already marked for eviction), so
160   // that an entry persists in the cache even if it is larger than the allocated
161   // cache size.
162   while (entries_by_last_use_.size() > 1 &&
163          cache_entries_ > max_cache_entries_) {
164     MarkOldestEntryForEviction();
165   }
166 }
167 
InitializeEntry(const string & key,const std::function<Status (std::unique_ptr<xla::LocalExecutable> *)> & initialize_program)168 XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
169     const string& key,
170     const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
171         initialize_program) {
172   CompiledSubgraph* entry = new CompiledSubgraph();
173   entry->parent = this;
174   entry->key = key;
175   entry->uid = get_uid();
176   // Add the entry to the cache. Once the computation has been compiled,
177   // UpdateEntryAfterCompilation will be called to potentially mark old entries
178   // that don't fit any more for eviction.
179   //
180   // At this point there is one reference to entry, which is owned by the caller
181   // who created the entry. A second reference, owned by the cache, will be
182   // added below since we leave the entry in the 'marked for eviction' state
183   // here.
184   auto cache_inserted =
185       cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
186   CHECK(cache_inserted.second);
187 
188   // Initialize the program outside the lock so that other cache operations
189   // can proceed during the (potentially lengthy) initialization.
190   Status s;
191   std::unique_ptr<xla::LocalExecutable> program;
192   {
193     mu_.Unlock();
194     { s = initialize_program(&program); }
195     mu_.Lock();
196   }
197 
198   // Add the entry to the uid index.
199   auto uid_inserted = entries_by_uid_.insert(
200       std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
201   CHECK(uid_inserted.second);
202 
203   entry->initialized = true;
204   entry->initialization_status = s;
205   if (s.ok()) {
206     entry->program = std::move(program);
207   }
208   // Add the entry to marked_for_eviction_entries_ since it will be adjusted
209   // down again when the newly-created entry gets unmarked.
210   ++marked_for_eviction_entries_;
211   return entry;
212 }
213 
CompileIfKeyAbsent(const string & key,int64 * uid,const std::function<Status (std::unique_ptr<xla::LocalExecutable> *)> & compile_function)214 Status XRTCompilationCache::CompileIfKeyAbsent(
215     const string& key, int64* uid,
216     const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
217         compile_function) {
218   CompiledSubgraph* entry = nullptr;
219 
220   absl::MutexLock lock(&mu_);
221   auto iter = cache_.find(key);
222 
223   if (iter == cache_.end()) {
224     // The single ref on the newly-created entry is owned by the caller.
225     VLOG(1) << "Before adding new entry for key " << key << " cache is "
226             << cache_.size() << " entries ("
227             << cache_entries_ + marked_for_eviction_entries_ << "), "
228             << " marked for eviction "
229             << (cache_.size() - entries_by_last_use_.size()) << " entries ("
230             << marked_for_eviction_entries_ << ").";
231     entry = InitializeEntry(key, compile_function);
232   } else {
233     VLOG(1) << "Before refreshing entry for key " << key << " cache is "
234             << cache_.size() << " entries ("
235             << cache_entries_ + marked_for_eviction_entries_ << "), "
236             << " marked for eviction "
237             << (cache_.size() - entries_by_last_use_.size()) << " entries ("
238             << marked_for_eviction_entries_ << ").";
239     entry = iter->second;
240     // Make a new reference that is owned by the caller.
241     entry->Ref();
242     // Block if necessary until the subgraph has been initialized.
243     mu_.Await(absl::Condition(
244         +[](CompiledSubgraph* e) { return e->initialized; }, entry));
245   }
246 
247   // Let the caller know the uid of the entry.
248   *uid = entry->uid;
249 
250   // Remove the old LRU-table entry if it wasn't already marked for eviction.
251   auto erased = entries_by_last_use_.erase(entry->last_use);
252   // Update the LRU table indicating this entry is the most recently used.
253   entry->last_use = use_counter_++;
254   entries_by_last_use_[entry->last_use] = entry;
255   if (erased == 0) {
256     // The entry had been marked for eviction, or is newly created.
257     LookupEntryMarkedForEviction(entry);
258   }
259 
260   VLOG(1) << "After refreshing entry for key " << key << " cache is "
261           << cache_.size() << " entries ("
262           << cache_entries_ + marked_for_eviction_entries_ << "), "
263           << " marked for eviction "
264           << (cache_.size() - entries_by_last_use_.size()) << " entries ("
265           << marked_for_eviction_entries_ << ").";
266 
267   return entry->initialization_status;
268 }
269 
Lookup(int64_t uid,std::unique_ptr<XRTCompilationCacheEntryRef> * entry)270 Status XRTCompilationCache::Lookup(
271     int64_t uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
272   entry->reset();
273 
274   absl::MutexLock lock(&mu_);
275   const auto iter = entries_by_uid_.find(uid);
276   if (iter == entries_by_uid_.end()) {
277     return errors::NotFound("No executable found for uid ", uid);
278   }
279   CompiledSubgraph* cache_entry = iter->second;
280   *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
281       new EntryRefImpl(this, cache_entry));
282   return Status::OK();
283 }
284 
DebugString() const285 string XRTCompilationCache::DebugString() const {
286   return "XRTCompilationCache";
287 }
288 
GetOrCreateCompilationCache(ResourceMgr * rm,int64_t max_number_of_entries)289 xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
290     ResourceMgr* rm, int64_t max_number_of_entries) {
291   if (max_number_of_entries == 0) {
292     max_number_of_entries = GetCompilationCacheSizeFromEnv();
293   }
294   XRTCompilationCache* cache;
295   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XRTCompilationCache>(
296       rm->default_container(), kXRTCompilationCacheResourceName, &cache,
297       [&](XRTCompilationCache** new_cache) {
298         *new_cache = new XRTCompilationCache(max_number_of_entries);
299         return Status::OK();
300       }));
301   return RefPtr<XRTCompilationCache>(cache);
302 }
303 
304 }  // namespace tensorflow
305