1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ 17 18 #include <map> 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "absl/synchronization/mutex.h" 25 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" 26 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h" 27 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" 28 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" 29 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h" 30 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" 31 32 namespace tensorflow { 33 namespace tpu { 34 35 // Class for looking up and caching TPU program via RPC. 36 class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup { 37 public: 38 using StubType = tpu::grpc::TpuCompilationCacheService::Stub; 39 40 TpuCompilationCacheRpcLookup(const string& server_address, 41 int64 max_cache_size); 42 ~TpuCompilationCacheRpcLookup() override = default; 43 44 Status Lookup(const string& proto_key, 45 std::unique_ptr<tpu::CompilationCacheEntryRef>* entry, 46 tpu::CompilationCacheFetchTarget fetch_target) override; 47 48 Status Lookup(int64 uid, int proto_index, 49 std::unique_ptr<tpu::CompilationCacheEntryRef>* entry, 50 tpu::CompilationCacheFetchTarget fetch_target) override; 51 52 string DebugString() const override; 53 54 private: 55 // Helper method to make the RPC request to the central cache. 56 Status RemoteLookupLocked(const string& local_proto_key, 57 const tpu::GetTpuProgramRequest& request, 58 std::shared_ptr<CacheEntry>* cache_entry) 59 ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); 60 61 // Helper method to adjust datastructures after a cache lookup. 62 // We use `removed_entries` so that actual CacheEntry destruction happens 63 // outside the lock. 64 void PostLookupLocked( 65 std::shared_ptr<CacheEntry>* cache_entry, 66 std::unique_ptr<tpu::CompilationCacheEntryRef>* entry, 67 std::vector<std::shared_ptr<CacheEntry>>* removed_entries) 68 ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); 69 70 // The maximum size of entries that are stored in the cache before entries are 71 // evicted. 72 const int64 max_cache_size_; 73 74 std::unique_ptr<StubType> stub_; 75 76 // Protect concurrent access to member variables below. 77 mutable absl::Mutex mu_; 78 79 // The total size of entries in the cache. 80 int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0; 81 // The value to assign to the last_use field of the next entry that is looked 82 // up. 83 int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0; 84 // The entries that can be looked up in the cache. An entry is deleted from 85 // the cache as soon as it is evicted, but the underlying shared_ptr won't be 86 // freed until any wrappers holding it go out of scope. 87 std::unordered_map<std::string, std::shared_ptr<CacheEntry>> cache_ 88 ABSL_GUARDED_BY(mu_); 89 // Map from last_use to entry, used to evict entries in LRU order. 90 std::map<int64, CacheEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_); 91 }; 92 } // namespace tpu 93 } // namespace tensorflow 94 95 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ 96