1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/call/request_buffer.h"
16
#include <cstdint>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "src/core/util/match.h"
22
23 namespace grpc_core {
24
Buffering()25 RequestBuffer::Buffering::Buffering() {}
26
RequestBuffer()27 RequestBuffer::RequestBuffer() : state_(absl::in_place_type_t<Buffering>()) {}
28
PushClientInitialMetadata(ClientMetadataHandle md)29 ValueOrFailure<size_t> RequestBuffer::PushClientInitialMetadata(
30 ClientMetadataHandle md) {
31 MutexLock lock(&mu_);
32 if (absl::get_if<Cancelled>(&state_)) return Failure{};
33 auto& buffering = absl::get<Buffering>(state_);
34 CHECK_EQ(buffering.initial_metadata.get(), nullptr);
35 buffering.initial_metadata = std::move(md);
36 buffering.buffered += buffering.initial_metadata->TransportSize();
37 WakeupAsyncAllPullers();
38 return buffering.buffered;
39 }
40
PollPushMessage(MessageHandle & message)41 Poll<ValueOrFailure<size_t>> RequestBuffer::PollPushMessage(
42 MessageHandle& message) {
43 MutexLock lock(&mu_);
44 if (absl::get_if<Cancelled>(&state_)) return Failure{};
45 size_t buffered = 0;
46 if (auto* buffering = absl::get_if<Buffering>(&state_)) {
47 if (winner_ != nullptr) return PendingPush();
48 buffering->buffered += message->payload()->Length();
49 buffered = buffering->buffered;
50 buffering->messages.push_back(std::move(message));
51 } else {
52 auto& streaming = absl::get<Streaming>(state_);
53 CHECK_EQ(streaming.end_of_stream, false);
54 if (streaming.message != nullptr) {
55 return PendingPush();
56 }
57 streaming.message = std::move(message);
58 }
59 WakeupAsyncAllPullers();
60 return buffered;
61 }
62
FinishSends()63 StatusFlag RequestBuffer::FinishSends() {
64 MutexLock lock(&mu_);
65 if (absl::get_if<Cancelled>(&state_)) return Failure{};
66 if (auto* buffering = absl::get_if<Buffering>(&state_)) {
67 Buffered buffered(std::move(buffering->initial_metadata),
68 std::move(buffering->messages));
69 state_.emplace<Buffered>(std::move(buffered));
70 } else {
71 auto& streaming = absl::get<Streaming>(state_);
72 CHECK_EQ(streaming.end_of_stream, false);
73 streaming.end_of_stream = true;
74 }
75 WakeupAsyncAllPullers();
76 return Success{};
77 }
78
Cancel(absl::Status error)79 void RequestBuffer::Cancel(absl::Status error) {
80 MutexLock lock(&mu_);
81 if (absl::holds_alternative<Cancelled>(state_)) return;
82 state_.emplace<Cancelled>(std::move(error));
83 WakeupAsyncAllPullers();
84 }
85
Commit(Reader * winner)86 void RequestBuffer::Commit(Reader* winner) {
87 MutexLock lock(&mu_);
88 CHECK_EQ(winner_, nullptr);
89 winner_ = winner;
90 if (auto* buffering = absl::get_if<Buffering>(&state_)) {
91 if (buffering->initial_metadata != nullptr &&
92 winner->message_index_ == buffering->messages.size() &&
93 winner->pulled_client_initial_metadata_) {
94 state_.emplace<Streaming>();
95 }
96 } else if (auto* buffered = absl::get_if<Buffered>(&state_)) {
97 CHECK_NE(buffered->initial_metadata.get(), nullptr);
98 if (winner->message_index_ == buffered->messages.size()) {
99 state_.emplace<Streaming>().end_of_stream = true;
100 }
101 }
102 WakeupAsyncAllPullersExcept(winner);
103 }
104
WakeupAsyncAllPullersExcept(Reader * except_reader)105 void RequestBuffer::WakeupAsyncAllPullersExcept(Reader* except_reader) {
106 for (auto wakeup_reader : readers_) {
107 if (wakeup_reader == except_reader) continue;
108 wakeup_reader->pull_waker_.WakeupAsync();
109 }
110 }
111
112 Poll<ValueOrFailure<ClientMetadataHandle>>
PollPullClientInitialMetadata()113 RequestBuffer::Reader::PollPullClientInitialMetadata() {
114 MutexLock lock(&buffer_->mu_);
115 if (buffer_->winner_ != nullptr && buffer_->winner_ != this) {
116 error_ = absl::CancelledError("Another call was chosen");
117 return Failure{};
118 }
119 if (auto* buffering = absl::get_if<Buffering>(&buffer_->state_)) {
120 if (buffering->initial_metadata.get() == nullptr) {
121 return buffer_->PendingPull(this);
122 }
123 pulled_client_initial_metadata_ = true;
124 auto result = ClaimObject(buffering->initial_metadata);
125 buffer_->MaybeSwitchToStreaming();
126 return std::move(result);
127 }
128 if (auto* buffered = absl::get_if<Buffered>(&buffer_->state_)) {
129 pulled_client_initial_metadata_ = true;
130 return ClaimObject(buffered->initial_metadata);
131 }
132 error_ = absl::get<Cancelled>(buffer_->state_).error;
133 return Failure{};
134 }
135
// Pulls the next client message for this reader.
//
// Returns:
// - Failure, with error_ set to "Another call was chosen", if a different
//   reader has been committed as the winner;
// - Failure, with error_ set to the recorded error, if the buffer was
//   cancelled;
// - absl::nullopt once every message has been pulled and sends are finished
//   (Buffered state at the end of the list, or Streaming with end_of_stream);
// - Pending while no message is available yet;
// - otherwise the next message.
//
// While Buffering or Buffered, each reader walks the buffered messages at its
// own message_index_.
Poll<ValueOrFailure<absl::optional<MessageHandle>>>
RequestBuffer::Reader::PollPullMessage() {
  ReleasableMutexLock lock(&buffer_->mu_);
  if (buffer_->winner_ != nullptr && buffer_->winner_ != this) {
    error_ = absl::CancelledError("Another call was chosen");
    return Failure{};
  }
  if (auto* buffering = absl::get_if<Buffering>(&buffer_->state_)) {
    // Caught up with everything pushed so far: wait for the next push.
    if (message_index_ == buffering->messages.size()) {
      return buffer_->PendingPull(this);
    }
    const auto idx = message_index_;
    auto result = ClaimObject(buffering->messages[idx]);
    ++message_index_;
    // Having advanced, this reader may now be a caught-up winner; let the
    // buffer decide whether to start streaming.
    buffer_->MaybeSwitchToStreaming();
    return std::move(result);
  }
  if (auto* buffered = absl::get_if<Buffered>(&buffer_->state_)) {
    // Sends are finished, so reaching the end of the list means end of stream.
    if (message_index_ == buffered->messages.size()) return absl::nullopt;
    const auto idx = message_index_;
    ++message_index_;
    return ClaimObject(buffered->messages[idx]);
  }
  if (auto* streaming = absl::get_if<Streaming>(&buffer_->state_)) {
    if (streaming->message == nullptr) {
      if (streaming->end_of_stream) return absl::nullopt;
      return buffer_->PendingPull(this);
    }
    // Take the single in-flight message and the parked pusher's waker, then
    // drop mu_ before the synchronous Wakeup(). Presumably this is so that a
    // pusher run inline can re-enter PollPushMessage without blocking on mu_.
    auto msg = std::move(streaming->message);
    auto waker = std::move(buffer_->push_waker_);
    lock.Release();
    waker.Wakeup();
    return std::move(msg);
  }
  error_ = absl::get<Cancelled>(buffer_->state_).error;
  return Failure{};
}
173
DebugString(Reader * caller)174 std::string RequestBuffer::DebugString(Reader* caller)
175 ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
176 return absl::StrCat(
177 "have_winner=",
178 (winner_ == nullptr ? "no" : (winner_ == caller ? "this" : "other")),
179 " num_readers=", readers_.size(),
180 " push_waker=", push_waker_.DebugString(),
181 Match(
182 state_,
183 [](const Buffering& buffering) {
184 return absl::StrCat(
185 " buffering initial_metadata=",
186 (buffering.initial_metadata != nullptr
187 ? buffering.initial_metadata->DebugString()
188 : "null"),
189 " messages=[",
190 absl::StrJoin(
191 buffering.messages, ",",
192 [](std::string* output, const MessageHandle& hdl) {
193 absl::StrAppend(output, hdl->DebugString());
194 }),
195 "] buffered=", buffering.buffered);
196 },
197 [](const Buffered& buffered) {
198 return absl::StrCat(
199 " buffered initial_metadata=",
200 (buffered.initial_metadata != nullptr
201 ? buffered.initial_metadata->DebugString()
202 : "null"),
203 " messages=[",
204 absl::StrJoin(
205 buffered.messages, ",",
206 [](std::string* output, const MessageHandle& hdl) {
207 absl::StrAppend(
208 output, hdl != nullptr ? hdl->DebugString() : "null");
209 }),
210 "]");
211 },
212 [](const Streaming& streaming) {
213 return absl::StrCat(
214 " streaming message=",
215 (streaming.message != nullptr ? streaming.message->DebugString()
216 : "null"),
217 " end_of_stream=", streaming.end_of_stream);
218 },
219 [](const Cancelled& cancelled) {
220 return absl::StrCat(" cancelled error=",
221 cancelled.error.ToString());
222 }));
223 }
224 } // namespace grpc_core
225