• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"

#include <set>
#include <utility>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
18 
19 namespace tensorflow {
20 
ToString(UntypedStreamingRPCState::Tag::TagType tag_type)21 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type) {
22   switch (tag_type) {
23     case UntypedStreamingRPCState::Tag::TagType::kCallStarted:
24       return "kCallStarted";
25     case UntypedStreamingRPCState::Tag::TagType::kRequestWriteCompleted:
26       return "kRequestWriteCompleted";
27     case UntypedStreamingRPCState::Tag::TagType::kResponseReadCompleted:
28       return "kResponseReadCompleted";
29     case UntypedStreamingRPCState::Tag::TagType::kCallFinished:
30       return "kCallFinished";
31   }
32 }
33 
Tag(UntypedStreamingRPCState * streaming_state,Tag::TagType type)34 UntypedStreamingRPCState::Tag::Tag(UntypedStreamingRPCState* streaming_state,
35                                    Tag::TagType type)
36     : streaming_state_(streaming_state), type_(type) {}
37 
OnCompleted(bool ok)38 void UntypedStreamingRPCState::Tag::OnCompleted(bool ok) {
39   switch (type_) {
40     case TagType::kCallStarted:
41       streaming_state_->CallStarted(ok);
42       break;
43     case TagType::kRequestWriteCompleted:
44       streaming_state_->RequestWriteCompleted(ok);
45       break;
46     case TagType::kResponseReadCompleted:
47       streaming_state_->ResponseReadCompleted(ok);
48       break;
49     case TagType::kCallFinished:
50       streaming_state_->CallFinished(ok);
51       break;
52   }
53   streaming_state_->Unref();  // Ref acquired when tag was handed to grpc.
54 }
55 
Complete(Status status)56 void Exchange::Complete(Status status) {
57   if (status.ok()) {
58     if (!GrpcMaybeParseProto(&response_buf_, response_)) {
59       status.Update(errors::Internal("could not parse rpc response"));
60     }
61   }
62   VLOG(3) << "Completing exchange " << DebugString() << " with "
63           << status.ToString();
64   cb_(status);
65 }
66 
operator <<(std::ostream & os,const Exchange::State & state)67 std::ostream& operator<<(std::ostream& os, const Exchange::State& state) {
68   os << ToString(state);
69   return os;
70 }
71 
ToString(Exchange::State state)72 const char* ToString(Exchange::State state) {
73   switch (state) {
74     case Exchange::State::kExchangeCreated:
75       return "ExchangeCreated";
76     case Exchange::State::kRequestWriteIssued:
77       return "RequestWriteIssued";
78     case Exchange::State::kRequestWriteCompleted:
79       return "RequestWriteCompleted";
80     case Exchange::State::kResponseReadIssued:
81       return "ResponseReadIssued";
82   }
83 }
84 
DebugString() const85 string Exchange::DebugString() const {
86   return absl::StrFormat("%p@%s_%s", this, ToString(state_), debug_string_);
87 }
88 
Emplace(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)89 void ExchangeQueue::Emplace(const ::grpc::ByteBuffer& request_buf,
90                             protobuf::Message* response, StatusCallback cb,
91                             string debug_string) {
92   exchanges_.emplace(exchanges_.end(), request_buf, response, std::move(cb),
93                      debug_string);
94 }
95 
GetReadyForRequestWriting()96 Exchange* ExchangeQueue::GetReadyForRequestWriting() {
97   CheckInvariants();
98   if (!call_started_) {
99     return nullptr;
100   }
101 
102   // TODO(iga): Optimize to avoid linear search.
103   for (Exchange& e : exchanges_) {
104     if (e.state() == Exchange::State::kExchangeCreated) {
105       return &e;
106     } else if (e.state() == Exchange::State::kRequestWriteIssued) {
107       return nullptr;
108     }
109   }
110   return nullptr;
111 }
112 
GetReadyForResponseReading()113 Exchange* ExchangeQueue::GetReadyForResponseReading() {
114   CheckInvariants();
115   if (!call_started_) {
116     // We should never ask for response reading when call has not
117     // been started, but it does not hurt to defensively check here anyway.
118     return nullptr;
119   }
120   if (exchanges_.empty()) {
121     return nullptr;
122   }
123   Exchange& e = exchanges_[0];
124   if (e.state() == Exchange::State::kRequestWriteCompleted) {
125     return &e;
126   }
127   return nullptr;
128 }
129 
MarkRequestWriteCompleted()130 void ExchangeQueue::MarkRequestWriteCompleted() {
131   CheckInvariants();
132   // TODO(iga): Optimize to avoid linear search.
133   for (Exchange& e : exchanges_) {
134     if (e.state() == Exchange::State::kRequestWriteIssued) {
135       e.MarkRequestWriteCompleted();
136     }
137   }
138   CheckInvariants();
139 }
140 
GetFront()141 Exchange& ExchangeQueue::GetFront() {
142   CheckInvariants();
143   return exchanges_.front();
144 }
145 
PopFront()146 void ExchangeQueue::PopFront() {
147   CheckInvariants();
148   exchanges_.pop_front();
149 }
150 
DebugString() const151 string ExchangeQueue::DebugString() const {
152   return absl::StrJoin(exchanges_, ", ", [](string* out, const Exchange& e) {
153     out->append(e.DebugString());
154   });
155 }
156 
Swap(ExchangeQueue * other)157 void ExchangeQueue::Swap(ExchangeQueue* other) {
158   exchanges_.swap(other->exchanges_);
159   std::swap(call_started_, other->call_started_);
160 }
161 
CompleteAll(Status status)162 void ExchangeQueue::CompleteAll(Status status) {
163   for (Exchange& exchange : exchanges_) {
164     exchange.Complete(status);
165   }
166 }
167 
168 namespace {
169 std::set<std::pair<Exchange::State, Exchange::State>>*
GetPossibleTransitions()170 GetPossibleTransitions() {
171   std::set<std::pair<Exchange::State, Exchange::State>>* s =
172       new std::set<std::pair<Exchange::State, Exchange::State>>();
173   // Regular state transitions
174   s->emplace(Exchange::State::kExchangeCreated,
175              Exchange::State::kRequestWriteIssued);
176   s->emplace(Exchange::State::kRequestWriteIssued,
177              Exchange::State::kRequestWriteCompleted);
178   s->emplace(Exchange::State::kRequestWriteCompleted,
179              Exchange::State::kResponseReadIssued);
180   // Self transitions. Possible when several exchanges can be in
181   // the same state.
182   s->emplace(Exchange::State::kExchangeCreated,
183              Exchange::State::kExchangeCreated);
184   s->emplace(Exchange::State::kRequestWriteCompleted,
185              Exchange::State::kRequestWriteCompleted);
186   // Skip transitions. Possible when there are no exchanges in a
187   // certain state.
188   s->emplace(Exchange::State::kExchangeCreated,
189              Exchange::State::kRequestWriteCompleted);
190   s->emplace(Exchange::State::kExchangeCreated,
191              Exchange::State::kResponseReadIssued);
192   s->emplace(Exchange::State::kRequestWriteIssued,
193              Exchange::State::kResponseReadIssued);
194   return s;
195 }
196 }  // namespace
197 
CheckInvariants()198 void ExchangeQueue::CheckInvariants() {
199   static std::set<std::pair<Exchange::State, Exchange::State>>*
200       possible_transitions = GetPossibleTransitions();
201 
202   if (!VLOG_IS_ON(5)) {
203     return;
204   }
205 
206   for (int i = 1, end = exchanges_.size(); i < end; ++i) {
207     const Exchange& e0 = exchanges_[i - 1];
208     const Exchange& e1 = exchanges_[i];
209     // The first exchange in the pair is the one that arrived later and is
210     // behind in processing.
211     auto p = std::make_pair(e1.state(), e0.state());
212     if (possible_transitions->find(p) == possible_transitions->end()) {
213       LOG(FATAL)
214           << "Found an impossible state transition in the exchange queue: "
215           << p.first << " -> " << p.second;
216     }
217   }
218 }
219 
220 }  // namespace tensorflow
221