/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
16 
17 #include "grpcpp/security/credentials.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/time/time.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
21 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
22 
namespace tensorflow {
namespace tpu {
namespace {

#if defined(LIBTPU_ON_GCE)
using ResponseType = GetTpuProgramResponseExternal;
#else
using ResponseType = GetTpuProgramResponse;
#endif

static constexpr absl::Duration kProtoTimeout = absl::Minutes(15);
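
// Converts an absl::Time to the gpr_timespec used for the gRPC client
// deadline, mapping infinite future/past to the corresponding gpr sentinels.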
static gpr_timespec TimeToGprTimespec(absl::Time time) {
  if (time == absl::InfiniteFuture()) {
    return gpr_inf_future(GPR_CLOCK_REALTIME);
  }
  if (time == absl::InfinitePast()) {
    return gpr_inf_past(GPR_CLOCK_REALTIME);
  }

  gpr_timespec spec;
  timespec t = absl::ToTimespec(time);
  spec.tv_sec = t.tv_sec;
  spec.tv_nsec = static_cast<int32_t>(t.tv_nsec);
  spec.clock_type = GPR_CLOCK_REALTIME;
  return spec;
}
}  // namespace

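// Opens a gRPC channel to the compilation cache server at `server_address`
// and creates a service stub for remote lookups. The local proto cache holds
// at most `max_cache_size` bytes of fetched TPU programs.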
TpuCompilationCacheRpcLookup(const std::string & server_address,int64 max_cache_size)50 TpuCompilationCacheRpcLookup::TpuCompilationCacheRpcLookup(
51     const std::string& server_address, int64 max_cache_size)
52     : max_cache_size_(max_cache_size) {
53   // Ensure that large TPU program can get sent over the channel.
54   ::grpc::ChannelArguments args;
55   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
56   auto channel =
57       ::grpc::CreateCustomChannel(absl::StrCat("dns:///", server_address),
58                                   CreateChannelCredentials(), args);
59   stub_ = tpu::grpc::TpuCompilationCacheService::NewStub(channel);
60   VLOG(1) << "Created RPC lookup cache size " << max_cache_size_ << " bytes.";
61 }
62 
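// Looks up the TPU program for `proto_key`, first in the local proto cache
// and, on a miss, via an RPC to the remote compilation cache server.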
Lookup(const std::string & proto_key,std::unique_ptr<CompilationCacheEntryRef> * entry,tpu::CompilationCacheFetchTarget fetch_target)63 Status TpuCompilationCacheRpcLookup::Lookup(
64     const std::string& proto_key,
65     std::unique_ptr<CompilationCacheEntryRef>* entry,
66     tpu::CompilationCacheFetchTarget fetch_target) {
67   profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup",
68                                          /*level=*/2);
69   entry->reset();
70   std::shared_ptr<CacheEntry> cache_entry;
71   // Keep a reference to CacheEntry objects evicted from the cache so that the
72   // potential deletion happens outside the lock upon method exit.
73   std::vector<std::shared_ptr<CacheEntry>> removed_entries;
74 
75   std::string local_proto_key = absl::StrCat(
76       proto_key, "_", tpu::CompilationCacheFetchTarget_Name(fetch_target));
77 
78   {
79     absl::MutexLock lock(&mu_);
80     auto iter = cache_.find(local_proto_key);
81     if (iter == cache_.end()) {
82       tpu::GetTpuProgramRequest request;
83       request.set_key(proto_key);
84       request.set_fetch_target(fetch_target);
85       TF_RETURN_IF_ERROR(
86           RemoteLookupLocked(local_proto_key, request, &cache_entry));
87     } else {
88       VLOG(1) << "Found key " << local_proto_key << " in local proto cache.";
89       cache_entry = iter->second;
90       auto erased = entries_by_last_use_.erase(cache_entry->last_use);
91       CHECK_EQ(erased, 1);
92     }
93     PostLookupLocked(&cache_entry, entry, &removed_entries);
94   }
95   return Status::OK();
96 }
97 
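// Same as the proto_key overload above, but looks up the program by
// (uid, proto_index).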
Status TpuCompilationCacheRpcLookup::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry,
    tpu::CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup by uid",
                                         /*level=*/2);
  entry->reset();
  std::shared_ptr<CacheEntry> cache_entry;
  // Keep a reference to CacheEntry objects evicted from the cache so that the
  // potential deletion happens outside the lock upon method exit.
  std::vector<std::shared_ptr<CacheEntry>> removed_entries;

  // Make a string key so that we can uniformly store cached entries under
  // string keys whether they are looked up by proto_key or uid+index. The
  // expectation is that any given executable will only ever be looked up
  // *either* by proto_key *or* by uid+index, so we are not concerned that the
  // same proto could be placed in the cache twice if it is looked up by both
  // methods.
  std::string local_proto_key =
      absl::StrCat(" _ ", uid, ":", proto_index, "_",
                   tpu::CompilationCacheFetchTarget_Name(fetch_target));
  {
    absl::MutexLock lock(&mu_);
    auto iter = cache_.find(local_proto_key);
    if (iter == cache_.end()) {
      tpu::GetTpuProgramRequest request;
      tpu::TpuCompilationUidAndIndex* uid_and_index =
          request.mutable_uid_and_index();
      uid_and_index->set_uid(uid);
      uid_and_index->set_proto_index(proto_index);
      request.set_fetch_target(fetch_target);
      TF_RETURN_IF_ERROR(
          RemoteLookupLocked(local_proto_key, request, &cache_entry));
    } else {
      VLOG(1) << "Found uid " << uid << " and index " << proto_index
              << " in local proto cache.";
      cache_entry = iter->second;
      auto erased = entries_by_last_use_.erase(cache_entry->last_use);
      CHECK_EQ(erased, 1);
    }
    PostLookupLocked(&cache_entry, entry, &removed_entries);
  }
  return Status::OK();
}

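// Fetches the requested TPU program from the remote cache server and inserts
// the deserialized entry into the local cache under `local_proto_key`.
// Called with `mu_` held.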
RemoteLookupLocked(const std::string & local_proto_key,const tpu::GetTpuProgramRequest & request,std::shared_ptr<CacheEntry> * cache_entry)143 Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
144     const std::string& local_proto_key,
145     const tpu::GetTpuProgramRequest& request,
146     std::shared_ptr<CacheEntry>* cache_entry) {
147   profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache fetch",
148                                          /*level=*/2);
149   // Perform the RPC while holding the lock unless it is demonstrated that
150   // this causes a performance problem.
151   ::grpc::ClientContext client_context;
152   client_context.set_deadline(TimeToGprTimespec(::absl::Now() + kProtoTimeout));
153   client_context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
154 
155   ResponseType response;
156   Status s =
157       FromGrpcStatus(stub_->GetTpuProgram(&client_context, request, &response));
158   VLOG(1) << "Looked up key " << local_proto_key
159           << " in remote subgraph cache status " << s;
160   TF_RETURN_IF_ERROR(s);
161 
162   TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry(
163       local_proto_key, &response, cache_entry));
164   cache_.emplace(local_proto_key, (*cache_entry));
165   cache_size_ += (*cache_entry)->size;
166 
167   return Status::OK();
168 }
169 
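// Records the entry as most recently used, hands a CacheWrapper reference back
// to the caller, and evicts least-recently-used entries while the cache is
// over `max_cache_size_`. Called with `mu_` held; evicted entries are returned
// in `removed_entries` so they can be destroyed after the lock is released.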
PostLookupLocked(std::shared_ptr<CacheEntry> * cache_entry,std::unique_ptr<CompilationCacheEntryRef> * entry,std::vector<std::shared_ptr<CacheEntry>> * removed_entries)170 void TpuCompilationCacheRpcLookup::PostLookupLocked(
171     std::shared_ptr<CacheEntry>* cache_entry,
172     std::unique_ptr<CompilationCacheEntryRef>* entry,
173     std::vector<std::shared_ptr<CacheEntry>>* removed_entries) {
174   (*cache_entry)->last_use = use_counter_++;
175   entries_by_last_use_[(*cache_entry)->last_use] = cache_entry->get();
176   *entry =
177       std::unique_ptr<CompilationCacheEntryRef>(new CacheWrapper(*cache_entry));
178 
179   // Evict overflowing entries if necessary, but never evict the most recently
180   // used entry.
181   while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
182     auto entry_to_evict = entries_by_last_use_.begin()->second;
183     entries_by_last_use_.erase(entry_to_evict->last_use);
184     CHECK_GE(cache_size_, entry_to_evict->size);
185     cache_size_ -= entry_to_evict->size;
186     // Delete the cache's reference to the entry, though clients may still be
187     // holding onto references. We use 'removed_entries' to delay the possible
188     // CacheEntry destruction until the mu_ lock is released.
189     auto entry_to_evict_it = cache_.find(entry_to_evict->key);
190     CHECK(entry_to_evict_it != cache_.end())
191         << "Missing entry key: " << entry_to_evict->key;
192     removed_entries->push_back(entry_to_evict_it->second);
193     cache_.erase(entry_to_evict_it);
194   }
195 }
196 
std::string TpuCompilationCacheRpcLookup::DebugString() const {
  return "TpuCompilationCacheRpcLookup";
}

}  // namespace tpu
}  // namespace tensorflow