/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"

namespace tensorflow {
namespace tpu {

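// RefHolder tracks the CompiledSubgraph entries referenced during a single
// step. It holds a reference to the parent cache for its own lifetime and
// discards all accumulated entry references in one batch when it is
// destroyed at the end of the step.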
TpuCompilationCacheInterface::RefHolder::RefHolder(
    TpuCompilationCacheInterface* parent)
    : parent_(parent) {
  // Hold a reference to the parent until the holder is discarded.
  parent_->Ref();
}

TpuCompilationCacheInterface::RefHolder::~RefHolder() {
  parent_->DiscardEntryRefs(entries_);
  // Release our reference to the parent.
  parent_->Unref();
}

void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
  entries_.push_back(entry);
}

std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
  return "TpuCompilationCacheRefHolder";
}

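// CompilationCacheEntryRef is a client-visible handle to a cache entry. All
// reference counting is done on the main entry: a reference taken on a nested
// sharding/unsharding entry is forwarded to its main_entry.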
CompilationCacheEntryRef::CompilationCacheEntryRef()
    : parent_(nullptr), entry_(nullptr), index_(0) {}

CompilationCacheEntryRef::CompilationCacheEntryRef(
    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only
    // refcount the main entry.
    entry_->main_entry->Ref();
  }
}

CompilationCacheEntryRef::~CompilationCacheEntryRef() {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}

TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
  if (entry_ == nullptr) {
    // Create an empty entry if the entry is nullptr. This corresponds to
    // non-existing sharding/unsharding entries.
    return TpuCompilationCacheEntry();
  }

  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}

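// Repoints this reference at the main, sharding, or unsharding program of the
// current entry. Since the refcount always lives on the main entry, switching
// between sub-entries needs no ref/unref; only a missing sub-entry causes the
// reference to be dropped and replaced with nullptr.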
Status CompilationCacheEntryRef::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // The cache entry does not have the requested sharding/unsharding
    // sub-entry. Unref the main entry and replace it with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always on the main entry, we don't
  // need ref/unref.
  entry_ = target;
  return Status::OK();
}

TpuCompilationCacheInterface::TpuCompilationCacheInterface(
    int64_t max_cache_size)
    : max_cache_size_(max_cache_size) {
  CHECK_GE(max_cache_size_, 0);
  VLOG(1) << "Created compilation cache of size " << max_cache_size_
          << " bytes.";
}

TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      Status s = Release(entry.first);
      CHECK(s.ok());
    }
  }
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted all reference holders should have already
  // been deleted, since they were holding references to the cache. So all
  // entries should be gone at this point.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}

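// Marks the entry with the given uid for eviction. This is only valid for
// entries without external references; entries a client still holds must go
// through Release() instead. The entry's size moves from cache_size_ to
// marked_for_eviction_size_, and the entry itself is destroyed once the last
// step referencing it completes.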
Status TpuCompilationCacheInterface::MarkEntryForEviction(
    int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // If already evicted, return ok.
      return Status::OK();
    }

    // Mark entry for eviction.
    CompiledSubgraph* subgraph_to_evict = iter->second;
    // This API must not be used while there are external references.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCacheInterface::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
            << " for eviction. Debug string: "
            << subgraph_to_evict->cache_entry_debug_string;
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict the entry if its refcount is exactly one; otherwise only discard
    // the cache's reference, and the actual eviction happens when the
    // remaining ref holders drop their references.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from device cache if entry is evicted from host cache.
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

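// Releases one external (client-held) reference to the entry with the given
// uid, matching a reference handed out by CompileIfKeyAbsent() when no
// per-step ref holder was supplied.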
Status TpuCompilationCacheInterface::Release(int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

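// Unloads the entry's programs from the TPU and destroys the entry. Called
// outside mu_, and only with the final reference, as checked below.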
void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (!entry) return;

  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}

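// Removes the entry with the given key from cache_, together with the
// session-handle and fingerprint lookup keys registered for keys with
// guaranteed constants. Callers hold mu_ (except during cache destruction,
// when no concurrent access is possible). Returns the number of entries
// erased from cache_.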
size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
  auto erased = cache_.erase(key);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}

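// Drops one reference to the entry. If it was the last reference, removes the
// entry from all cache maps and returns it so the caller can destroy it via
// UnloadAndDestroy() outside the lock; otherwise returns nullptr.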
CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
    CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    for (const std::string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}

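// Returns a new per-step holder; callers pass it to CompileIfKeyAbsent() so
// that entry references taken during a step are discarded together when the
// step finishes and the holder is deleted.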
CompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
  return new RefHolder(this);
}

void TpuCompilationCacheInterface::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  std::vector<CompiledSubgraph*> removed_entries;
  {
    absl::MutexLock lock(&mu_);

    for (auto entry : entries) {
      removed_entries.push_back(DiscardEntryRef(entry));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  for (auto removed_entry : removed_entries) {
    UnloadAndDestroy(removed_entry);
  }
}

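// Marks the least-recently-used entry for eviction and discards the cache's
// reference to it. entries_by_last_use_ must be non-empty when this is
// called.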
CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key
          << " for eviction. Debug string: "
          << entry_to_mark->cache_entry_debug_string;
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}

void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size,
  // and mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark
  // the most-recently used entry (i.e., do nothing if
  // entries_by_last_use_.size() == 1, which means there is only one entry not
  // already marked for eviction), so that an entry persists in the cache even
  // if it is larger than the allocated cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}

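// Adds the entry to cache_ under the given key and, for keys with guaranteed
// constants, registers the session-handle and fingerprint lookup keys that
// FindCacheKey() uses.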
void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
                                               CompiledSubgraph* entry) {
  auto cache_inserted =
      cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle), key));
  fingerprint_key_map_.insert(
      std::make_pair(strings::StrCat(parsed_key.prefix,
                                     parsed_key.guaranteed_const_fingerprint()),
                     key));
}

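// Looks up (and, on a miss, compiles) the program for subgraph_key. This
// wrapper delegates to CompileIfKeyAbsentHelper() and then destroys any
// entries evicted during the lookup outside of the cache lock.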
Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  std::vector<CompiledSubgraph*> removed_entries;
  auto status = CompileIfKeyAbsentHelper(
      subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
      compile_function);
  for (auto entry : removed_entries) {
    UnloadAndDestroy(entry);
  }
  return status;
}

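// Resolves a subgraph key to the key under which the entry is stored in
// cache_. For keys with guaranteed constants the lookup is two-stage: first
// by session handle, then by the guaranteed-const fingerprint. Returns an
// empty string if no match is found.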
std::string TpuCompilationCacheInterface::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key " << subgraph_key.ToString();
  return "";
}

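// Does the real work of CompileIfKeyAbsent(). Runs under mu_, but
// InitializeEntry() releases and reacquires the lock while the actual
// compilation runs, so the lock is not held for the whole call. Entries
// evicted while making room are appended to *removed_entries for the caller
// to destroy outside the lock.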
Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  CompiledSubgraph* entry = nullptr;

  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: Although we use MutexLock, we do not hold the lock for the lifetime
  // of the object; see the InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_.find(cache_key);
  bool is_new_key = iter == cache_.end();

  const std::string session_name =
      tpu::SessionNameFromMetadata(session_metadata);

  if (is_new_key) {
    cache_key = subgraph_key.ToString();
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
      const std::string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_.begin(); it != cache_.end(); ++it) {
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    bool compilation_success = entry->tpu_program_group->program_count() > 0;
    TRACELITERAL("TPU host compilation cache: compilation done.");
    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation ",
        compilation_success ? "complete" : "failed", " for cache_key(",
        cache_key, "), session_name(", session_name, "), subgraph_key(",
        subgraph_key.debug_string, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly to the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/true, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *sharding_key = entry->sharding_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
  *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder then the caller is going
  // to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder
    // that will discard the reference when the step completes.
    RefHolder* cast_ref_holder =
        tensorflow::down_cast<RefHolder*>(per_step_ref_holder);
    CHECK_NE(cast_ref_holder, nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table, making this entry the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << " cache is " << cache_.size() << " entries ("
              << cache_size_ + marked_for_eviction_size_ << " bytes), "
              << " marked for eviction "
              << (cache_.size() - entries_by_last_use_.size()) << " entries ("
              << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}

Status TpuCompilationCacheInterface::GetKeysFromUid(
    int64_t uid, std::vector<std::string>* keys) {
  keys->clear();

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  *keys = iter->second->proto_key;
  return Status::OK();
}

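// The two Lookup() overloads hand out a caller-owned CompilationCacheEntryRef,
// either by (uid, per-core proto index) or by proto key. The reference taken
// here is dropped when the returned CompilationCacheEntryRef is destroyed.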
Status TpuCompilationCacheInterface::Lookup(
    int64_t uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program_group->program_count()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return Status::OK();
}

Status TpuCompilationCacheInterface::Lookup(
    const std::string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompiledSubgraph* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return Status::OK();
}
}  // namespace tpu
}  // namespace tensorflow