• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/call/request_buffer.h"
16 
17 #include <cstdint>
18 
19 #include "absl/strings/str_cat.h"
20 #include "absl/types/optional.h"
21 #include "src/core/util/match.h"
22 
23 namespace grpc_core {
24 
Buffering()25 RequestBuffer::Buffering::Buffering() {}
26 
RequestBuffer()27 RequestBuffer::RequestBuffer() : state_(absl::in_place_type_t<Buffering>()) {}
28 
PushClientInitialMetadata(ClientMetadataHandle md)29 ValueOrFailure<size_t> RequestBuffer::PushClientInitialMetadata(
30     ClientMetadataHandle md) {
31   MutexLock lock(&mu_);
32   if (absl::get_if<Cancelled>(&state_)) return Failure{};
33   auto& buffering = absl::get<Buffering>(state_);
34   CHECK_EQ(buffering.initial_metadata.get(), nullptr);
35   buffering.initial_metadata = std::move(md);
36   buffering.buffered += buffering.initial_metadata->TransportSize();
37   WakeupAsyncAllPullers();
38   return buffering.buffered;
39 }
40 
PollPushMessage(MessageHandle & message)41 Poll<ValueOrFailure<size_t>> RequestBuffer::PollPushMessage(
42     MessageHandle& message) {
43   MutexLock lock(&mu_);
44   if (absl::get_if<Cancelled>(&state_)) return Failure{};
45   size_t buffered = 0;
46   if (auto* buffering = absl::get_if<Buffering>(&state_)) {
47     if (winner_ != nullptr) return PendingPush();
48     buffering->buffered += message->payload()->Length();
49     buffered = buffering->buffered;
50     buffering->messages.push_back(std::move(message));
51   } else {
52     auto& streaming = absl::get<Streaming>(state_);
53     CHECK_EQ(streaming.end_of_stream, false);
54     if (streaming.message != nullptr) {
55       return PendingPush();
56     }
57     streaming.message = std::move(message);
58   }
59   WakeupAsyncAllPullers();
60   return buffered;
61 }
62 
FinishSends()63 StatusFlag RequestBuffer::FinishSends() {
64   MutexLock lock(&mu_);
65   if (absl::get_if<Cancelled>(&state_)) return Failure{};
66   if (auto* buffering = absl::get_if<Buffering>(&state_)) {
67     Buffered buffered(std::move(buffering->initial_metadata),
68                       std::move(buffering->messages));
69     state_.emplace<Buffered>(std::move(buffered));
70   } else {
71     auto& streaming = absl::get<Streaming>(state_);
72     CHECK_EQ(streaming.end_of_stream, false);
73     streaming.end_of_stream = true;
74   }
75   WakeupAsyncAllPullers();
76   return Success{};
77 }
78 
Cancel(absl::Status error)79 void RequestBuffer::Cancel(absl::Status error) {
80   MutexLock lock(&mu_);
81   if (absl::holds_alternative<Cancelled>(state_)) return;
82   state_.emplace<Cancelled>(std::move(error));
83   WakeupAsyncAllPullers();
84 }
85 
Commit(Reader * winner)86 void RequestBuffer::Commit(Reader* winner) {
87   MutexLock lock(&mu_);
88   CHECK_EQ(winner_, nullptr);
89   winner_ = winner;
90   if (auto* buffering = absl::get_if<Buffering>(&state_)) {
91     if (buffering->initial_metadata != nullptr &&
92         winner->message_index_ == buffering->messages.size() &&
93         winner->pulled_client_initial_metadata_) {
94       state_.emplace<Streaming>();
95     }
96   } else if (auto* buffered = absl::get_if<Buffered>(&state_)) {
97     CHECK_NE(buffered->initial_metadata.get(), nullptr);
98     if (winner->message_index_ == buffered->messages.size()) {
99       state_.emplace<Streaming>().end_of_stream = true;
100     }
101   }
102   WakeupAsyncAllPullersExcept(winner);
103 }
104 
WakeupAsyncAllPullersExcept(Reader * except_reader)105 void RequestBuffer::WakeupAsyncAllPullersExcept(Reader* except_reader) {
106   for (auto wakeup_reader : readers_) {
107     if (wakeup_reader == except_reader) continue;
108     wakeup_reader->pull_waker_.WakeupAsync();
109   }
110 }
111 
112 Poll<ValueOrFailure<ClientMetadataHandle>>
PollPullClientInitialMetadata()113 RequestBuffer::Reader::PollPullClientInitialMetadata() {
114   MutexLock lock(&buffer_->mu_);
115   if (buffer_->winner_ != nullptr && buffer_->winner_ != this) {
116     error_ = absl::CancelledError("Another call was chosen");
117     return Failure{};
118   }
119   if (auto* buffering = absl::get_if<Buffering>(&buffer_->state_)) {
120     if (buffering->initial_metadata.get() == nullptr) {
121       return buffer_->PendingPull(this);
122     }
123     pulled_client_initial_metadata_ = true;
124     auto result = ClaimObject(buffering->initial_metadata);
125     buffer_->MaybeSwitchToStreaming();
126     return std::move(result);
127   }
128   if (auto* buffered = absl::get_if<Buffered>(&buffer_->state_)) {
129     pulled_client_initial_metadata_ = true;
130     return ClaimObject(buffered->initial_metadata);
131   }
132   error_ = absl::get<Cancelled>(buffer_->state_).error;
133   return Failure{};
134 }
135 
136 Poll<ValueOrFailure<absl::optional<MessageHandle>>>
PollPullMessage()137 RequestBuffer::Reader::PollPullMessage() {
138   ReleasableMutexLock lock(&buffer_->mu_);
139   if (buffer_->winner_ != nullptr && buffer_->winner_ != this) {
140     error_ = absl::CancelledError("Another call was chosen");
141     return Failure{};
142   }
143   if (auto* buffering = absl::get_if<Buffering>(&buffer_->state_)) {
144     if (message_index_ == buffering->messages.size()) {
145       return buffer_->PendingPull(this);
146     }
147     const auto idx = message_index_;
148     auto result = ClaimObject(buffering->messages[idx]);
149     ++message_index_;
150     buffer_->MaybeSwitchToStreaming();
151     return std::move(result);
152   }
153   if (auto* buffered = absl::get_if<Buffered>(&buffer_->state_)) {
154     if (message_index_ == buffered->messages.size()) return absl::nullopt;
155     const auto idx = message_index_;
156     ++message_index_;
157     return ClaimObject(buffered->messages[idx]);
158   }
159   if (auto* streaming = absl::get_if<Streaming>(&buffer_->state_)) {
160     if (streaming->message == nullptr) {
161       if (streaming->end_of_stream) return absl::nullopt;
162       return buffer_->PendingPull(this);
163     }
164     auto msg = std::move(streaming->message);
165     auto waker = std::move(buffer_->push_waker_);
166     lock.Release();
167     waker.Wakeup();
168     return std::move(msg);
169   }
170   error_ = absl::get<Cancelled>(buffer_->state_).error;
171   return Failure{};
172 }
173 
DebugString(Reader * caller)174 std::string RequestBuffer::DebugString(Reader* caller)
175     ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
176   return absl::StrCat(
177       "have_winner=",
178       (winner_ == nullptr ? "no" : (winner_ == caller ? "this" : "other")),
179       " num_readers=", readers_.size(),
180       " push_waker=", push_waker_.DebugString(),
181       Match(
182           state_,
183           [](const Buffering& buffering) {
184             return absl::StrCat(
185                 " buffering initial_metadata=",
186                 (buffering.initial_metadata != nullptr
187                      ? buffering.initial_metadata->DebugString()
188                      : "null"),
189                 " messages=[",
190                 absl::StrJoin(
191                     buffering.messages, ",",
192                     [](std::string* output, const MessageHandle& hdl) {
193                       absl::StrAppend(output, hdl->DebugString());
194                     }),
195                 "] buffered=", buffering.buffered);
196           },
197           [](const Buffered& buffered) {
198             return absl::StrCat(
199                 " buffered initial_metadata=",
200                 (buffered.initial_metadata != nullptr
201                      ? buffered.initial_metadata->DebugString()
202                      : "null"),
203                 " messages=[",
204                 absl::StrJoin(
205                     buffered.messages, ",",
206                     [](std::string* output, const MessageHandle& hdl) {
207                       absl::StrAppend(
208                           output, hdl != nullptr ? hdl->DebugString() : "null");
209                     }),
210                 "]");
211           },
212           [](const Streaming& streaming) {
213             return absl::StrCat(
214                 " streaming message=",
215                 (streaming.message != nullptr ? streaming.message->DebugString()
216                                               : "null"),
217                 " end_of_stream=", streaming.end_of_stream);
218           },
219           [](const Cancelled& cancelled) {
220             return absl::StrCat(" cancelled error=",
221                                 cancelled.error.ToString());
222           }));
223 }
224 }  // namespace grpc_core
225