1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
17 #include "absl/types/optional.h"
18 #include "tensorflow/core/platform/env.h"
19
20 namespace tensorflow {
21
QueueRequest(int64 request_id,int64 step_id,const FinishResponseCB & cb)22 bool GrpcResponseCache::QueueRequest(int64 request_id, int64 step_id,
23 const FinishResponseCB& cb) {
24 VLOG(1) << "GrpcResponseCache Lookup " << request_id;
25
26 mu_.lock();
27
28 ResponseCacheEntry& entry = response_cache_[request_id];
29
30 if (entry.state == ResponseCacheEntry::State::FINISHED) {
31 VLOG(1) << "Reuse cached response for " << request_id;
32 // Make a copy of the ResponseCacheEntry so that we can run FinishResponse
33 // outside the critical section. FinishResponse can be potentially
34 // expensive.
35 auto entry_copy = entry;
36
37 mu_.unlock();
38 entry_copy.FinishResponse(cb);
39 return true;
40 }
41
42 entry.callbacks.emplace_back(cb);
43
44 if (entry.state == ResponseCacheEntry::State::ACTIVE) {
45 VLOG(1) << "Found active request for " << request_id
46 << ". Adding entry to response queue.";
47 mu_.unlock();
48 return true;
49 } else {
50 VLOG(2) << "No cache entry for " << request_id
51 << ", running user computation.";
52 entry.step_id = step_id;
53 entry.state = ResponseCacheEntry::State::ACTIVE;
54 mu_.unlock();
55 return false;
56 }
57 }
58
OnRequestFinished(int64 request_id,const Tensor & tensor,bool is_dead,const Status & status)59 void GrpcResponseCache::OnRequestFinished(int64 request_id,
60 const Tensor& tensor, bool is_dead,
61 const Status& status) {
62 absl::optional<ResponseCacheEntry> entry_copy;
63
64 {
65 mutex_lock m(mu_);
66
67 auto it = response_cache_.find(request_id);
68 if (it == response_cache_.end()) {
69 LOG(ERROR) << "Unexpected missing response cache entry for request "
70 << request_id;
71 return;
72 }
73 ResponseCacheEntry& entry = it->second;
74
75 VLOG(1) << "Operation for " << request_id << " finished. "
76 << "Status: " << status << ", tensor size " << tensor.TotalBytes()
77 << " bytes, " << entry.callbacks.size() << " pending callbacks.";
78
79 entry.tensor = tensor;
80 entry.is_dead = is_dead;
81 entry.response_status = status;
82 entry.state = ResponseCacheEntry::State::FINISHED;
83
84 // We copy the extra work out of the critical section in order to avoid
85 // serializing the work for sending response.
86 entry_copy = entry;
87
88 entry.callbacks.clear();
89 }
90
91 for (auto& cb : entry_copy->callbacks) {
92 entry_copy->FinishResponse(cb);
93 }
94 }
95
EraseRequestId(int64 request_id)96 void GrpcResponseCache::EraseRequestId(int64 request_id) {
97 mutex_lock m(mu_);
98 response_cache_.erase(request_id);
99 }
100
CleanEntriesForStep(int64 step_id)101 void GrpcResponseCache::CleanEntriesForStep(int64 step_id) {
102 mutex_lock m(mu_);
103 // Remove all cache entries whose step id is the given step_id
104 for (auto it = response_cache_.begin(), last = response_cache_.end();
105 it != last;) {
106 if (it->second.step_id == step_id) {
107 VLOG(1) << "Erase stale GrpcResponseCache entry " << it->first;
108 it = response_cache_.erase(it);
109 } else {
110 ++it;
111 }
112 }
113 }
114
115 } // namespace tensorflow
116