1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
16
17 #include "absl/strings/str_format.h"
18
19 namespace tensorflow {
20
ToString(UntypedStreamingRPCState::Tag::TagType tag_type)21 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type) {
22 switch (tag_type) {
23 case UntypedStreamingRPCState::Tag::TagType::kCallStarted:
24 return "kCallStarted";
25 case UntypedStreamingRPCState::Tag::TagType::kRequestWriteCompleted:
26 return "kRequestWriteCompleted";
27 case UntypedStreamingRPCState::Tag::TagType::kResponseReadCompleted:
28 return "kResponseReadCompleted";
29 case UntypedStreamingRPCState::Tag::TagType::kCallFinished:
30 return "kCallFinished";
31 }
32 }
33
Tag(UntypedStreamingRPCState * streaming_state,Tag::TagType type)34 UntypedStreamingRPCState::Tag::Tag(UntypedStreamingRPCState* streaming_state,
35 Tag::TagType type)
36 : streaming_state_(streaming_state), type_(type) {}
37
OnCompleted(bool ok)38 void UntypedStreamingRPCState::Tag::OnCompleted(bool ok) {
39 switch (type_) {
40 case TagType::kCallStarted:
41 streaming_state_->CallStarted(ok);
42 break;
43 case TagType::kRequestWriteCompleted:
44 streaming_state_->RequestWriteCompleted(ok);
45 break;
46 case TagType::kResponseReadCompleted:
47 streaming_state_->ResponseReadCompleted(ok);
48 break;
49 case TagType::kCallFinished:
50 streaming_state_->CallFinished(ok);
51 break;
52 }
53 streaming_state_->Unref(); // Ref acquired when tag was handed to grpc.
54 }
55
Complete(Status status)56 void Exchange::Complete(Status status) {
57 if (status.ok()) {
58 if (!GrpcMaybeParseProto(&response_buf_, response_)) {
59 status.Update(errors::Internal("could not parse rpc response"));
60 }
61 }
62 VLOG(3) << "Completing exchange " << DebugString() << " with "
63 << status.ToString();
64 cb_(status);
65 }
66
operator <<(std::ostream & os,const Exchange::State & state)67 std::ostream& operator<<(std::ostream& os, const Exchange::State& state) {
68 os << ToString(state);
69 return os;
70 }
71
ToString(Exchange::State state)72 const char* ToString(Exchange::State state) {
73 switch (state) {
74 case Exchange::State::kExchangeCreated:
75 return "ExchangeCreated";
76 case Exchange::State::kRequestWriteIssued:
77 return "RequestWriteIssued";
78 case Exchange::State::kRequestWriteCompleted:
79 return "RequestWriteCompleted";
80 case Exchange::State::kResponseReadIssued:
81 return "ResponseReadIssued";
82 }
83 }
84
DebugString() const85 string Exchange::DebugString() const {
86 return absl::StrFormat("%p@%s_%s", this, ToString(state_), debug_string_);
87 }
88
Emplace(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)89 void ExchangeQueue::Emplace(const ::grpc::ByteBuffer& request_buf,
90 protobuf::Message* response, StatusCallback cb,
91 string debug_string) {
92 exchanges_.emplace(exchanges_.end(), request_buf, response, std::move(cb),
93 debug_string);
94 }
95
GetReadyForRequestWriting()96 Exchange* ExchangeQueue::GetReadyForRequestWriting() {
97 CheckInvariants();
98 if (!call_started_) {
99 return nullptr;
100 }
101
102 // TODO(iga): Optimize to avoid linear search.
103 for (Exchange& e : exchanges_) {
104 if (e.state() == Exchange::State::kExchangeCreated) {
105 return &e;
106 } else if (e.state() == Exchange::State::kRequestWriteIssued) {
107 return nullptr;
108 }
109 }
110 return nullptr;
111 }
112
GetReadyForResponseReading()113 Exchange* ExchangeQueue::GetReadyForResponseReading() {
114 CheckInvariants();
115 if (!call_started_) {
116 // We should never ask for response reading when call has not
117 // been started, but it does not hurt to defensively check here anyway.
118 return nullptr;
119 }
120 if (exchanges_.empty()) {
121 return nullptr;
122 }
123 Exchange& e = exchanges_[0];
124 if (e.state() == Exchange::State::kRequestWriteCompleted) {
125 return &e;
126 }
127 return nullptr;
128 }
129
MarkRequestWriteCompleted()130 void ExchangeQueue::MarkRequestWriteCompleted() {
131 CheckInvariants();
132 // TODO(iga): Optimize to avoid linear search.
133 for (Exchange& e : exchanges_) {
134 if (e.state() == Exchange::State::kRequestWriteIssued) {
135 e.MarkRequestWriteCompleted();
136 }
137 }
138 CheckInvariants();
139 }
140
GetFront()141 Exchange& ExchangeQueue::GetFront() {
142 CheckInvariants();
143 return exchanges_.front();
144 }
145
PopFront()146 void ExchangeQueue::PopFront() {
147 CheckInvariants();
148 exchanges_.pop_front();
149 }
150
DebugString() const151 string ExchangeQueue::DebugString() const {
152 return absl::StrJoin(exchanges_, ", ", [](string* out, const Exchange& e) {
153 out->append(e.DebugString());
154 });
155 }
156
Swap(ExchangeQueue * other)157 void ExchangeQueue::Swap(ExchangeQueue* other) {
158 exchanges_.swap(other->exchanges_);
159 std::swap(call_started_, other->call_started_);
160 }
161
CompleteAll(Status status)162 void ExchangeQueue::CompleteAll(Status status) {
163 for (Exchange& exchange : exchanges_) {
164 exchange.Complete(status);
165 }
166 }
167
168 namespace {
169 std::set<std::pair<Exchange::State, Exchange::State>>*
GetPossibleTransitions()170 GetPossibleTransitions() {
171 std::set<std::pair<Exchange::State, Exchange::State>>* s =
172 new std::set<std::pair<Exchange::State, Exchange::State>>();
173 // Regular state transitions
174 s->emplace(Exchange::State::kExchangeCreated,
175 Exchange::State::kRequestWriteIssued);
176 s->emplace(Exchange::State::kRequestWriteIssued,
177 Exchange::State::kRequestWriteCompleted);
178 s->emplace(Exchange::State::kRequestWriteCompleted,
179 Exchange::State::kResponseReadIssued);
180 // Self transitions. Possible when several exchanges can be in
181 // the same state.
182 s->emplace(Exchange::State::kExchangeCreated,
183 Exchange::State::kExchangeCreated);
184 s->emplace(Exchange::State::kRequestWriteCompleted,
185 Exchange::State::kRequestWriteCompleted);
186 // Skip transitions. Possible when there are no exchanges in a
187 // certain state.
188 s->emplace(Exchange::State::kExchangeCreated,
189 Exchange::State::kRequestWriteCompleted);
190 s->emplace(Exchange::State::kExchangeCreated,
191 Exchange::State::kResponseReadIssued);
192 s->emplace(Exchange::State::kRequestWriteIssued,
193 Exchange::State::kResponseReadIssued);
194 return s;
195 }
196 } // namespace
197
CheckInvariants()198 void ExchangeQueue::CheckInvariants() {
199 static std::set<std::pair<Exchange::State, Exchange::State>>*
200 possible_transitions = GetPossibleTransitions();
201
202 if (!VLOG_IS_ON(5)) {
203 return;
204 }
205
206 for (int i = 1, end = exchanges_.size(); i < end; ++i) {
207 const Exchange& e0 = exchanges_[i - 1];
208 const Exchange& e1 = exchanges_[i];
209 // The first exchange in the pair is the one that arrived later and is
210 // behind in processing.
211 auto p = std::make_pair(e1.state(), e0.state());
212 if (possible_transitions->find(p) == possible_transitions->end()) {
213 LOG(FATAL)
214 << "Found an impossible state transition in the exchange queue: "
215 << p.first << " -> " << p.second;
216 }
217 }
218 }
219
220 } // namespace tensorflow
221