/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"

namespace tensorflow {
namespace tpu {

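// RefHolder collects references to cache entries on behalf of a single step.
// It keeps the parent cache alive for its own lifetime and discards all
// collected entry references in one batch in its destructor; instances are
// created via MakePerStepRefHolder() below.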
TpuCompilationCacheInterface::RefHolder::RefHolder(
    TpuCompilationCacheInterface* parent)
    : parent_(parent) {
  // Hold a reference to the parent until the holder is discarded.
  parent_->Ref();
}

TpuCompilationCacheInterface::RefHolder::~RefHolder() {
  parent_->DiscardEntryRefs(entries_);
  // Release our reference to the parent.
  parent_->Unref();
}

void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
  entries_.push_back(entry);
}

std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
  return "TpuCompilationCacheRefHolder";
}

CompilationCacheEntryRef::CompilationCacheEntryRef()
    : parent_(nullptr), entry_(nullptr), index_(0) {}

CompilationCacheEntryRef::CompilationCacheEntryRef(
    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only
    // refcount the main entry.
    entry_->main_entry->Ref();
  }
}

CompilationCacheEntryRef::~CompilationCacheEntryRef() {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}

TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
  if (entry_ == nullptr) {
    // Create an empty entry if the entry is nullptr. This corresponds to
    // non-existing sharding/unsharding entries.
    return TpuCompilationCacheEntry();
  }

  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}

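// Narrows this reference from a main entry to one of its sharding/unsharding
// subentries. A minimal usage sketch (illustrative only), assuming `cache`
// points at a concrete TpuCompilationCacheInterface and `uid` names an
// existing entry:
//
//   std::unique_ptr<CompilationCacheEntryRef> ref;
//   TF_RETURN_IF_ERROR(cache->Lookup(uid, /*proto_index=*/0, &ref));
//   TF_RETURN_IF_ERROR(
//       ref->ToSubEntryRef(CompilationCacheFetchTarget::SHARDING));
//   // ref->get() now yields the sharding program, or an empty entry if the
//   // main entry has no sharding subentry.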
Status CompilationCacheEntryRef::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // The cache entry does not have the requested sharding/unsharding
    // subentry. Unref the main entry and replace it with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always on the main entry, we don't
  // need ref/unref.
  entry_ = target;
  return Status::OK();
}

TpuCompilationCacheInterface::TpuCompilationCacheInterface(
    int64_t max_cache_size)
    : max_cache_size_(max_cache_size) {
  CHECK_GE(max_cache_size_, 0);
  VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
}

TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      Status s = Release(entry.first);
      CHECK(s.ok());
    }
  }
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted all reference holders should have already
  // been deleted, since they were holding references to the cache. So all
  // entries should be gone at this point.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}

Status TpuCompilationCacheInterface::MarkEntryForEviction(
    int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // If already evicted, return ok.
      return Status::OK();
    }

    // Mark the entry for eviction.
    CompiledSubgraph* subgraph_to_evict = iter->second;
    // This API must not be used while there are external references.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCacheInterface::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
            << " for eviction. Debug string: "
            << subgraph_to_evict->cache_entry_debug_string;
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict if the refcount is exactly one; otherwise only discard the cache's
    // reference to the entry, and the actual eviction happens when the
    // refholder's references go away.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from the device cache if the entry was evicted from the host cache.
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

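// Releases one external (client-held) reference on the entry identified by
// `subgraph_uid`. This is the manual counterpart to passing a null
// per_step_ref_holder to CompileIfKeyAbsent(): each such call increments
// external_references, and each Release() call drops one.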
Status TpuCompilationCacheInterface::Release(int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (!entry) return;

  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}

size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
  auto erased = cache_.erase(key);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}

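// Entry teardown is two-phase: DiscardEntryRef() runs under mu_ and, when the
// last reference is going away, unlinks the entry from all lookup tables and
// returns the doomed entry to the caller; the caller then invokes
// UnloadAndDestroy() outside the lock, so device-program unloading never
// happens while mu_ is held.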
CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
    CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    for (const std::string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}

CompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
  return new RefHolder(this);
}

void TpuCompilationCacheInterface::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  std::vector<CompiledSubgraph*> removed_entries;
  {
    absl::MutexLock lock(&mu_);

    for (auto entry : entries) {
      removed_entries.push_back(DiscardEntryRef(entry));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  for (auto removed_entry : removed_entries) {
    UnloadAndDestroy(removed_entry);
  }
}

CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key
          << " for eviction. Debug string: "
          << entry_to_mark->cache_entry_debug_string;
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to the entry. If steps are holding onto
  // references to the entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}

void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size,
  // and mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size()
  // == 1, which means there is only one entry not already marked for
  // eviction), so that an entry persists in the cache even if it is larger
  // than the allocated cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}

void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
                                               CompiledSubgraph* entry) {
  auto cache_inserted =
      cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle), key));
  fingerprint_key_map_.insert(
      std::make_pair(strings::StrCat(parsed_key.prefix,
                                     parsed_key.guaranteed_const_fingerprint()),
                     key));
}

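// Compiles the subgraph for `subgraph_key` if no matching cache entry exists,
// and hands back a reference plus lookup metadata for the cached programs.
// A minimal call sketch (illustrative only; `cache`, `metadata`, and `holder`
// are assumed to come from elsewhere, and `DoCompile` is a hypothetical
// wrapper around the actual TPU compilation):
//
//   int64 uid;
//   std::vector<std::string> proto_key, sharding_key;
//   std::vector<bool> may_modify_variables;
//   absl::Span<const xla::HloProto* const> hlo_metadatas;
//   TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
//       subgraph_key, metadata, holder, &uid, &proto_key, &sharding_key,
//       &may_modify_variables, &hlo_metadatas,
//       [&](TpuProgramGroupInterface* programs) {
//         return DoCompile(programs);  // hypothetical compile callback
//       }));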
Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  std::vector<CompiledSubgraph*> removed_entries;
  auto status = CompileIfKeyAbsentHelper(
      subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
      compile_function);
  for (auto entry : removed_entries) {
    UnloadAndDestroy(entry);
  }
  return status;
}

std::string TpuCompilationCacheInterface::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key " << subgraph_key.ToString();
  return "";
}

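// Helper for CompileIfKeyAbsent. Runs under mu_, but the lock is not held
// continuously: InitializeEntry() releases and reacquires mu_ around the
// (potentially long-running) compile_function, and cache hits block on
// mu_.Await until the winning thread marks the entry initialized.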
Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  CompiledSubgraph* entry = nullptr;

  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
  // for the lifetime of the object; see the InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_.find(cache_key);
  bool is_new_key = iter == cache_.end();

  const std::string session_name =
      tpu::SessionNameFromMetadata(session_metadata);

  if (is_new_key) {
    cache_key = subgraph_key.ToString();
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if the caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
      const std::string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache missed. Current cache entries:";
        for (auto it = cache_.begin(); it != cache_.end(); ++it) {
          VLOG(2) << "Cache debug info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    bool compilation_success = entry->tpu_program_group->program_count() > 0;
    TRACELITERAL("TPU host compilation cache: compilation done.");
    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation ",
        compilation_success ? "complete" : "failed", " for cache_key(",
        cache_key, "), session_name(", session_name, "), subgraph_key(",
        subgraph_key.debug_string, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly to the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/true, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *sharding_key = entry->sharding_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
  *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder then the caller is going
  // to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference handed off to a per-step holder that
    // will discard the reference when the step completes.
    RefHolder* cast_ref_holder =
        tensorflow::down_cast<RefHolder*>(per_step_ref_holder);
    CHECK_NE(cast_ref_holder, nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table, indicating this entry is the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << "; cache is " << cache_.size() << " entries ("
              << cache_size_ + marked_for_eviction_size_
              << " bytes), marked for eviction "
              << (cache_.size() - entries_by_last_use_.size()) << " entries ("
              << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}

Status TpuCompilationCacheInterface::GetKeysFromUid(
    int64_t uid, std::vector<std::string>* keys) {
  keys->clear();

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  *keys = iter->second->proto_key;
  return Status::OK();
}

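// Looks up a cached program by the uid returned from CompileIfKeyAbsent plus
// a per-core proto_index; the overload below resolves one of the fully
// qualified proto_key strings instead. On success, *entry holds a reference
// to the cached programs that lives until the CompilationCacheEntryRef is
// destroyed.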
Status TpuCompilationCacheInterface::Lookup(
    int64_t uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program_group->program_count()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return Status::OK();
}

Status TpuCompilationCacheInterface::Lookup(
    const std::string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompiledSubgraph* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow