/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"

#include <limits>

#include "grpcpp/security/credentials.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"

namespace tensorflow {
namespace tpu {
namespace {

#if defined(LIBTPU_ON_GCE)
using ResponseType = GetTpuProgramResponseExternal;
#else
using ResponseType = GetTpuProgramResponse;
#endif

static constexpr absl::Duration kProtoTimeout = absl::Minutes(15);
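
// Converts an absl::Time into the gpr_timespec representation that gRPC uses
// for deadlines, mapping absl's infinite past/future sentinels onto the
// corresponding gpr constants.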
static gpr_timespec TimeToGprTimespec(absl::Time time) {
  if (time == absl::InfiniteFuture()) {
    return gpr_inf_future(GPR_CLOCK_REALTIME);
  }
  if (time == absl::InfinitePast()) {
    return gpr_inf_past(GPR_CLOCK_REALTIME);
  }

  gpr_timespec spec;
  timespec t = absl::ToTimespec(time);
  spec.tv_sec = t.tv_sec;
  spec.tv_nsec = static_cast<int32_t>(t.tv_nsec);
  spec.clock_type = GPR_CLOCK_REALTIME;
  return spec;
}
}  // namespace
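
// Opens a gRPC channel to the compilation cache server at `server_address`
// (with the maximum message size raised so large compiled programs fit) and
// creates the TpuCompilationCacheService stub used for remote lookups.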
TpuCompilationCacheRpcLookup::TpuCompilationCacheRpcLookup(
    const std::string& server_address, int64 max_cache_size)
    : max_cache_size_(max_cache_size) {
  // Ensure that large TPU programs can be sent over the channel.
  ::grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
  auto channel =
      ::grpc::CreateCustomChannel(absl::StrCat("dns:///", server_address),
                                  CreateChannelCredentials(), args);
  stub_ = tpu::grpc::TpuCompilationCacheService::NewStub(channel);
  VLOG(1) << "Created RPC lookup cache size " << max_cache_size_ << " bytes.";
}

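// Looks up the program cached under `proto_key` for the given fetch target,
// serving it from the local cache when present and otherwise fetching it from
// the remote cache server via RPC before handing back an entry reference.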
Status TpuCompilationCacheRpcLookup::Lookup(
    const std::string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry,
    tpu::CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup",
                                         /*level=*/2);
  entry->reset();
  std::shared_ptr<CacheEntry> cache_entry;
  // Keep a reference to CacheEntry objects evicted from the cache so that the
  // potential deletion happens outside the lock upon method exit.
  std::vector<std::shared_ptr<CacheEntry>> removed_entries;

  std::string local_proto_key = absl::StrCat(
      proto_key, "_", tpu::CompilationCacheFetchTarget_Name(fetch_target));

  {
    absl::MutexLock lock(&mu_);
    auto iter = cache_.find(local_proto_key);
    if (iter == cache_.end()) {
      tpu::GetTpuProgramRequest request;
      request.set_key(proto_key);
      request.set_fetch_target(fetch_target);
      TF_RETURN_IF_ERROR(
          RemoteLookupLocked(local_proto_key, request, &cache_entry));
    } else {
      VLOG(1) << "Found key " << local_proto_key << " in local proto cache.";
      cache_entry = iter->second;
      auto erased = entries_by_last_use_.erase(cache_entry->last_use);
      CHECK_EQ(erased, 1);
    }
    PostLookupLocked(&cache_entry, entry, &removed_entries);
  }
  return Status::OK();
}

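// Same as the proto_key overload above, but identifies the cached program by
// its (uid, proto_index) pair instead of by proto_key.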
Status TpuCompilationCacheRpcLookup::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry,
    tpu::CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup by uid",
                                         /*level=*/2);
  entry->reset();
  std::shared_ptr<CacheEntry> cache_entry;
  // Keep a reference to CacheEntry objects evicted from the cache so that the
  // potential deletion happens outside the lock upon method exit.
  std::vector<std::shared_ptr<CacheEntry>> removed_entries;

  // Make a string key so that we can uniformly store cached entries under
  // string keys whether they are looked up by proto_key or uid+index. The
  // expectation is that any given executable will only ever be looked up
  // *either* by proto_key *or* by uid+index, so we are not concerned that the
  // same proto could be placed in the cache twice if it is looked up by both
  // methods.
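  // For example, uid 1234 with proto_index 0 yields a key of the form
  // " _ 1234:0_<fetch_target_name>".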
  std::string local_proto_key =
      absl::StrCat(" _ ", uid, ":", proto_index, "_",
                   tpu::CompilationCacheFetchTarget_Name(fetch_target));
  {
    absl::MutexLock lock(&mu_);
    auto iter = cache_.find(local_proto_key);
    if (iter == cache_.end()) {
      tpu::GetTpuProgramRequest request;
      tpu::TpuCompilationUidAndIndex* uid_and_index =
          request.mutable_uid_and_index();
      uid_and_index->set_uid(uid);
      uid_and_index->set_proto_index(proto_index);
      request.set_fetch_target(fetch_target);
      TF_RETURN_IF_ERROR(
          RemoteLookupLocked(local_proto_key, request, &cache_entry));
    } else {
      VLOG(1) << "Found uid " << uid << " and index " << proto_index
              << " in local proto cache.";
      cache_entry = iter->second;
      auto erased = entries_by_last_use_.erase(cache_entry->last_use);
      CHECK_EQ(erased, 1);
    }
    PostLookupLocked(&cache_entry, entry, &removed_entries);
  }
  return Status::OK();
}

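// Fetches the program described by `request` from the remote cache server
// over gRPC, deserializes the response into a new CacheEntry, and records it
// in the local cache under `local_proto_key`. Must be called with `mu_` held.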
Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
    const std::string& local_proto_key,
    const tpu::GetTpuProgramRequest& request,
    std::shared_ptr<CacheEntry>* cache_entry) {
  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache fetch",
                                         /*level=*/2);
  // Perform the RPC while holding the lock unless it is demonstrated that
  // this causes a performance problem.
  ::grpc::ClientContext client_context;
  client_context.set_deadline(
      TimeToGprTimespec(::absl::Now() + kProtoTimeout));
  client_context.set_compression_algorithm(GRPC_COMPRESS_GZIP);

  ResponseType response;
  Status s =
      FromGrpcStatus(stub_->GetTpuProgram(&client_context, request, &response));
  VLOG(1) << "Looked up key " << local_proto_key
          << " in remote subgraph cache status " << s;
  TF_RETURN_IF_ERROR(s);

  TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry(
      local_proto_key, &response, cache_entry));
  cache_.emplace(local_proto_key, (*cache_entry));
  cache_size_ += (*cache_entry)->size;

  return Status::OK();
}

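// Marks `cache_entry` as most recently used, wraps it in the returned entry
// ref, and evicts least-recently-used entries while the cache exceeds
// `max_cache_size_`. Evicted entries are appended to `removed_entries` so
// their possible destruction happens after `mu_` is released.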
void TpuCompilationCacheRpcLookup::PostLookupLocked(
    std::shared_ptr<CacheEntry>* cache_entry,
    std::unique_ptr<CompilationCacheEntryRef>* entry,
    std::vector<std::shared_ptr<CacheEntry>>* removed_entries) {
  (*cache_entry)->last_use = use_counter_++;
  entries_by_last_use_[(*cache_entry)->last_use] = cache_entry->get();
  *entry =
      std::unique_ptr<CompilationCacheEntryRef>(new CacheWrapper(*cache_entry));

  // Evict overflowing entries if necessary, but never evict the most recently
  // used entry.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    auto entry_to_evict = entries_by_last_use_.begin()->second;
    entries_by_last_use_.erase(entry_to_evict->last_use);
    CHECK_GE(cache_size_, entry_to_evict->size);
    cache_size_ -= entry_to_evict->size;
    // Delete the cache's reference to the entry, though clients may still be
    // holding onto references. We use 'removed_entries' to delay the possible
    // CacheEntry destruction until the mu_ lock is released.
    auto entry_to_evict_it = cache_.find(entry_to_evict->key);
    CHECK(entry_to_evict_it != cache_.end())
        << "Missing entry key: " << entry_to_evict->key;
    removed_entries->push_back(entry_to_evict_it->second);
    cache_.erase(entry_to_evict_it);
  }
}

std::string TpuCompilationCacheRpcLookup::DebugString() const {
  return "TpuCompilationCacheRpcLookup";
}
}  // namespace tpu
}  // namespace tensorflow