1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
16 #define GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
17
18 #include <grpc/support/port_platform.h>
19
20 #include "absl/log/check.h"
21 #include "src/core/lib/promise/detail/status.h"
22 #include "src/core/lib/promise/if.h"
23 #include "src/core/lib/promise/latch.h"
24 #include "src/core/lib/promise/party.h"
25 #include "src/core/lib/promise/pipe.h"
26 #include "src/core/lib/promise/prioritized_race.h"
27 #include "src/core/lib/promise/promise.h"
28 #include "src/core/lib/promise/status_flag.h"
29 #include "src/core/lib/promise/try_seq.h"
30 #include "src/core/lib/transport/call_arena_allocator.h"
31 #include "src/core/lib/transport/call_filters.h"
32 #include "src/core/lib/transport/message.h"
33 #include "src/core/lib/transport/metadata.h"
34 #include "src/core/util/dual_ref_counted.h"
35
36 namespace grpc_core {
37
38 // The common middle part of a call - a reference is held by each of
39 // CallInitiator and CallHandler - which provide interfaces that are appropriate
40 // for each side of a call.
41 // Hosts context, call filters, and the arena.
42 class CallSpine final : public Party {
43 public:
Create(ClientMetadataHandle client_initial_metadata,RefCountedPtr<Arena> arena)44 static RefCountedPtr<CallSpine> Create(
45 ClientMetadataHandle client_initial_metadata,
46 RefCountedPtr<Arena> arena) {
47 Arena* arena_ptr = arena.get();
48 return RefCountedPtr<CallSpine>(arena_ptr->New<CallSpine>(
49 std::move(client_initial_metadata), std::move(arena)));
50 }
51
~CallSpine()52 ~CallSpine() override { CallOnDone(true); }
53
call_filters()54 CallFilters& call_filters() { return call_filters_; }
55
56 // Add a callback to be called when server trailing metadata is received and
57 // return true.
58 // If CallOnDone has already been invoked, does nothing and returns false.
OnDone(absl::AnyInvocable<void (bool)> fn)59 GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
60 if (call_filters().WasServerTrailingMetadataPulled()) {
61 return false;
62 }
63 if (on_done_ == nullptr) {
64 on_done_ = std::move(fn);
65 return true;
66 }
67 on_done_ = [first = std::move(fn),
68 next = std::move(on_done_)](bool cancelled) mutable {
69 first(cancelled);
70 next(cancelled);
71 };
72 return true;
73 }
CallOnDone(bool cancelled)74 void CallOnDone(bool cancelled) {
75 if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled);
76 }
77
PullServerInitialMetadata()78 auto PullServerInitialMetadata() {
79 return call_filters().PullServerInitialMetadata();
80 }
81
PullServerTrailingMetadata()82 auto PullServerTrailingMetadata() {
83 return Map(
84 call_filters().PullServerTrailingMetadata(),
85 [this](ServerMetadataHandle result) {
86 CallOnDone(result->get(GrpcCallWasCancelled()).value_or(false));
87 return result;
88 });
89 }
90
PushClientToServerMessage(MessageHandle message)91 auto PushClientToServerMessage(MessageHandle message) {
92 return call_filters().PushClientToServerMessage(std::move(message));
93 }
94
PullClientToServerMessage()95 auto PullClientToServerMessage() {
96 return call_filters().PullClientToServerMessage();
97 }
98
PushServerToClientMessage(MessageHandle message)99 auto PushServerToClientMessage(MessageHandle message) {
100 return call_filters().PushServerToClientMessage(std::move(message));
101 }
102
PullServerToClientMessage()103 auto PullServerToClientMessage() {
104 return call_filters().PullServerToClientMessage();
105 }
106
PushServerTrailingMetadata(ServerMetadataHandle md)107 void PushServerTrailingMetadata(ServerMetadataHandle md) {
108 GRPC_TRACE_LOG(call_state, INFO)
109 << "[call_state] PushServerTrailingMetadata: " << md->DebugString();
110 call_filters().PushServerTrailingMetadata(std::move(md));
111 }
112
FinishSends()113 void FinishSends() { call_filters().FinishClientToServerSends(); }
114
PullClientInitialMetadata()115 auto PullClientInitialMetadata() {
116 return call_filters().PullClientInitialMetadata();
117 }
118
PushServerInitialMetadata(ServerMetadataHandle md)119 StatusFlag PushServerInitialMetadata(ServerMetadataHandle md) {
120 return call_filters().PushServerInitialMetadata(std::move(md));
121 }
122
WasCancelled()123 auto WasCancelled() { return call_filters().WasCancelled(); }
124
UnprocessedClientInitialMetadata()125 ClientMetadata& UnprocessedClientInitialMetadata() {
126 return *call_filters().unprocessed_client_initial_metadata();
127 }
128
129 // Wrap a promise so that if it returns failure it automatically cancels
130 // the rest of the call.
131 // The resulting (returned) promise will resolve to Empty.
132 template <typename Promise>
CancelIfFails(Promise promise)133 auto CancelIfFails(Promise promise) {
134 DCHECK(GetContext<Activity>() == this);
135 using P = promise_detail::PromiseLike<Promise>;
136 using ResultType = typename P::Result;
137 return Map(std::move(promise), [this](ResultType r) {
138 CancelIfFailed(r);
139 return r;
140 });
141 }
142
143 template <typename StatusType>
CancelIfFailed(const StatusType & r)144 void CancelIfFailed(const StatusType& r) {
145 if (!IsStatusOk(r)) {
146 Cancel();
147 }
148 }
149
Cancel()150 void Cancel() { call_filters().Cancel(); }
151
152 // Spawn a promise that returns Empty{} and save some boilerplate handling
153 // that detail.
154 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)155 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
156 Spawn(name, std::move(promise_factory), [](Empty) {});
157 }
158
159 // Spawn a promise that returns some status-like type; if the status
160 // represents failure automatically cancel the rest of the call.
161 template <typename PromiseFactory>
162 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
163 DebugLocation whence = {}) {
164 using FactoryType =
165 promise_detail::OncePromiseFactory<void, PromiseFactory>;
166 using PromiseType = typename FactoryType::Promise;
167 using ResultType = typename PromiseType::Result;
168 static_assert(
169 std::is_same<bool,
170 decltype(IsStatusOk(std::declval<ResultType>()))>::value,
171 "SpawnGuarded promise must return a status-like object");
172 Spawn(name, std::move(promise_factory), [this, whence](ResultType r) {
173 if (!IsStatusOk(r)) {
174 GRPC_TRACE_LOG(promise_primitives, INFO)
175 << "SpawnGuarded sees failure: " << r
176 << " (source: " << whence.file() << ":" << whence.line() << ")";
177 auto status = StatusCast<ServerMetadataHandle>(std::move(r));
178 status->Set(GrpcCallWasCancelled(), true);
179 PushServerTrailingMetadata(std::move(status));
180 }
181 });
182 }
183
184 // Wrap a promise so that if the call completes that promise is cancelled.
185 template <typename Promise>
UntilCallCompletes(Promise promise)186 auto UntilCallCompletes(Promise promise) {
187 using Result = PromiseResult<Promise>;
188 return PrioritizedRace(std::move(promise), Map(WasCancelled(), [](bool) {
189 return FailureStatusCast<Result>(Failure{});
190 }));
191 }
192
193 template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)194 void SpawnGuardedUntilCallCompletes(absl::string_view name,
195 PromiseFactory promise_factory) {
196 SpawnGuarded(name, [this, promise_factory]() mutable {
197 return UntilCallCompletes(promise_factory());
198 });
199 }
200
201 // Spawned operations: these are callable from /outside/ the call; they spawn
202 // an operation into the call and execute that operation.
203
SpawnPushServerInitialMetadata(ServerMetadataHandle md)204 void SpawnPushServerInitialMetadata(ServerMetadataHandle md) {
205 SpawnInfallible(
206 "push-server-initial-metadata",
207 [md = std::move(md), self = RefAsSubclass<CallSpine>()]() mutable {
208 self->CancelIfFailed(self->PushServerInitialMetadata(std::move(md)));
209 });
210 }
211
SpawnPushServerToClientMessage(MessageHandle msg)212 auto SpawnPushServerToClientMessage(MessageHandle msg) {
213 return SpawnWaitable(
214 "push-message",
215 [msg = std::move(msg), self = RefAsSubclass<CallSpine>()]() mutable {
216 return self->CancelIfFails(
217 self->PushServerToClientMessage(std::move(msg)));
218 });
219 }
220
SpawnPushClientToServerMessage(MessageHandle msg)221 auto SpawnPushClientToServerMessage(MessageHandle msg) {
222 return SpawnWaitable(
223 "push-message",
224 [msg = std::move(msg), self = RefAsSubclass<CallSpine>()]() mutable {
225 return self->CancelIfFails(
226 self->PushClientToServerMessage(std::move(msg)));
227 });
228 }
229
SpawnFinishSends()230 void SpawnFinishSends() {
231 SpawnInfallible("finish-sends", [self = RefAsSubclass<CallSpine>()]() {
232 self->FinishSends();
233 return Empty{};
234 });
235 }
236
SpawnPushServerTrailingMetadata(ServerMetadataHandle md)237 void SpawnPushServerTrailingMetadata(ServerMetadataHandle md) {
238 SpawnInfallible(
239 "push-server-trailing-metadata",
240 [md = std::move(md), self = RefAsSubclass<CallSpine>()]() mutable {
241 self->PushServerTrailingMetadata(std::move(md));
242 return Empty{};
243 });
244 }
245
SpawnCancel()246 void SpawnCancel() {
247 SpawnInfallible("cancel", [self = RefAsSubclass<CallSpine>()]() {
248 self->call_filters().Cancel();
249 });
250 }
251
AddChildCall(RefCountedPtr<CallSpine> child_call)252 void AddChildCall(RefCountedPtr<CallSpine> child_call) {
253 child_calls_.push_back(std::move(child_call));
254 if (child_calls_.size() == 1) {
255 SpawnInfallible(
256 "check_cancellation", [self = RefAsSubclass<CallSpine>()]() mutable {
257 auto was_completed =
258 self->call_filters().ServerTrailingMetadataWasPushed();
259 return Map(std::move(was_completed),
260 [self = std::move(self)](Empty) {
261 for (auto& child : self->child_calls_) {
262 child->SpawnCancel();
263 }
264 return Empty{};
265 });
266 });
267 }
268 }
269
270 private:
271 friend class Arena;
CallSpine(ClientMetadataHandle client_initial_metadata,RefCountedPtr<Arena> arena)272 CallSpine(ClientMetadataHandle client_initial_metadata,
273 RefCountedPtr<Arena> arena)
274 : Party(std::move(arena)),
275 call_filters_(std::move(client_initial_metadata)) {}
276
277 // Call filters/pipes part of the spine
278 CallFilters call_filters_;
279 absl::AnyInvocable<void(bool)> on_done_{nullptr};
280 // Call spines that should be cancelled if this spine is cancelled
281 absl::InlinedVector<RefCountedPtr<CallSpine>, 3> child_calls_;
282 };
283
284 class CallHandler;
285
286 class CallInitiator {
287 public:
288 using NextMessage = ServerToClientNextMessage;
289
290 CallInitiator() = default;
CallInitiator(RefCountedPtr<CallSpine> spine)291 explicit CallInitiator(RefCountedPtr<CallSpine> spine)
292 : spine_(std::move(spine)) {}
293
294 template <typename Promise>
CancelIfFails(Promise promise)295 auto CancelIfFails(Promise promise) {
296 return spine_->CancelIfFails(std::move(promise));
297 }
298
PullServerInitialMetadata()299 auto PullServerInitialMetadata() {
300 return spine_->PullServerInitialMetadata();
301 }
302
PushMessage(MessageHandle message)303 auto PushMessage(MessageHandle message) {
304 return spine_->PushClientToServerMessage(std::move(message));
305 }
306
SpawnPushMessage(MessageHandle message)307 auto SpawnPushMessage(MessageHandle message) {
308 return spine_->SpawnPushClientToServerMessage(std::move(message));
309 }
310
FinishSends()311 void FinishSends() { spine_->FinishSends(); }
312
SpawnFinishSends()313 void SpawnFinishSends() { spine_->SpawnFinishSends(); }
314
PullMessage()315 auto PullMessage() { return spine_->PullServerToClientMessage(); }
316
PullServerTrailingMetadata()317 auto PullServerTrailingMetadata() {
318 return spine_->PullServerTrailingMetadata();
319 }
320
Cancel(absl::Status error)321 void Cancel(absl::Status error) {
322 CHECK(!error.ok());
323 auto status = ServerMetadataFromStatus(error);
324 status->Set(GrpcCallWasCancelled(), true);
325 spine_->PushServerTrailingMetadata(std::move(status));
326 }
327
SpawnCancel(absl::Status error)328 void SpawnCancel(absl::Status error) {
329 CHECK(!error.ok());
330 auto status = ServerMetadataFromStatus(error);
331 status->Set(GrpcCallWasCancelled(), true);
332 spine_->SpawnPushServerTrailingMetadata(std::move(status));
333 }
334
Cancel()335 void Cancel() { spine_->Cancel(); }
336
SpawnCancel()337 void SpawnCancel() { spine_->SpawnCancel(); }
338
OnDone(absl::AnyInvocable<void (bool)> fn)339 GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
340 return spine_->OnDone(std::move(fn));
341 }
342
343 template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)344 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
345 spine_->SpawnGuarded(name, std::move(promise_factory));
346 }
347
348 template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)349 void SpawnGuardedUntilCallCompletes(absl::string_view name,
350 PromiseFactory promise_factory) {
351 spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
352 }
353
354 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)355 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
356 spine_->SpawnInfallible(name, std::move(promise_factory));
357 }
358
359 template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)360 auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
361 return spine_->SpawnWaitable(name, std::move(promise_factory));
362 }
363
WasCancelledPushed()364 bool WasCancelledPushed() const {
365 return spine_->call_filters().WasCancelledPushed();
366 }
367
arena()368 Arena* arena() { return spine_->arena(); }
party()369 Party* party() { return spine_.get(); }
370
371 private:
372 friend class CallHandler;
373 RefCountedPtr<CallSpine> spine_;
374 };
375
376 class CallHandler {
377 public:
378 using NextMessage = ClientToServerNextMessage;
379
CallHandler(RefCountedPtr<CallSpine> spine)380 explicit CallHandler(RefCountedPtr<CallSpine> spine)
381 : spine_(std::move(spine)) {}
382
PullClientInitialMetadata()383 auto PullClientInitialMetadata() {
384 return spine_->PullClientInitialMetadata();
385 }
386
PushServerInitialMetadata(ServerMetadataHandle md)387 auto PushServerInitialMetadata(ServerMetadataHandle md) {
388 return spine_->PushServerInitialMetadata(std::move(md));
389 }
390
SpawnPushServerInitialMetadata(ServerMetadataHandle md)391 void SpawnPushServerInitialMetadata(ServerMetadataHandle md) {
392 return spine_->SpawnPushServerInitialMetadata(std::move(md));
393 }
394
PushServerTrailingMetadata(ServerMetadataHandle status)395 void PushServerTrailingMetadata(ServerMetadataHandle status) {
396 spine_->PushServerTrailingMetadata(std::move(status));
397 }
398
SpawnPushServerTrailingMetadata(ServerMetadataHandle status)399 void SpawnPushServerTrailingMetadata(ServerMetadataHandle status) {
400 spine_->SpawnPushServerTrailingMetadata(std::move(status));
401 }
402
OnDone(absl::AnyInvocable<void (bool)> fn)403 GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
404 return spine_->OnDone(std::move(fn));
405 }
406
407 template <typename Promise>
CancelIfFails(Promise promise)408 auto CancelIfFails(Promise promise) {
409 return spine_->CancelIfFails(std::move(promise));
410 }
411
PushMessage(MessageHandle message)412 auto PushMessage(MessageHandle message) {
413 return spine_->PushServerToClientMessage(std::move(message));
414 }
415
SpawnPushMessage(MessageHandle message)416 auto SpawnPushMessage(MessageHandle message) {
417 return spine_->SpawnPushServerToClientMessage(std::move(message));
418 }
419
PullMessage()420 auto PullMessage() { return spine_->PullClientToServerMessage(); }
421
WasCancelled()422 auto WasCancelled() { return spine_->WasCancelled(); }
423
WasCancelledPushed()424 bool WasCancelledPushed() const {
425 return spine_->call_filters().WasCancelledPushed();
426 }
427
428 template <typename PromiseFactory>
429 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
430 DebugLocation whence = {}) {
431 spine_->SpawnGuarded(name, std::move(promise_factory), whence);
432 }
433
434 template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)435 void SpawnGuardedUntilCallCompletes(absl::string_view name,
436 PromiseFactory promise_factory) {
437 spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
438 }
439
440 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)441 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
442 spine_->SpawnInfallible(name, std::move(promise_factory));
443 }
444
445 template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)446 auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
447 return spine_->SpawnWaitable(name, std::move(promise_factory));
448 }
449
AddChildCall(const CallInitiator & initiator)450 void AddChildCall(const CallInitiator& initiator) {
451 CHECK(initiator.spine_ != nullptr);
452 spine_->AddChildCall(initiator.spine_);
453 }
454
arena()455 Arena* arena() { return spine_->arena(); }
party()456 Party* party() { return spine_.get(); }
457
458 private:
459 RefCountedPtr<CallSpine> spine_;
460 };
461
462 class UnstartedCallHandler {
463 public:
UnstartedCallHandler(RefCountedPtr<CallSpine> spine)464 explicit UnstartedCallHandler(RefCountedPtr<CallSpine> spine)
465 : spine_(std::move(spine)) {}
466
PushServerTrailingMetadata(ServerMetadataHandle status)467 void PushServerTrailingMetadata(ServerMetadataHandle status) {
468 spine_->PushServerTrailingMetadata(std::move(status));
469 }
470
OnDone(absl::AnyInvocable<void (bool)> fn)471 GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
472 return spine_->OnDone(std::move(fn));
473 }
474
475 template <typename Promise>
CancelIfFails(Promise promise)476 auto CancelIfFails(Promise promise) {
477 return spine_->CancelIfFails(std::move(promise));
478 }
479
480 template <typename PromiseFactory>
481 void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
482 DebugLocation whence = {}) {
483 spine_->SpawnGuarded(name, std::move(promise_factory), whence);
484 }
485
486 template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)487 void SpawnGuardedUntilCallCompletes(absl::string_view name,
488 PromiseFactory promise_factory) {
489 spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
490 }
491
492 template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)493 void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
494 spine_->SpawnInfallible(name, std::move(promise_factory));
495 }
496
497 template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)498 auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
499 return spine_->SpawnWaitable(name, std::move(promise_factory));
500 }
501
UnprocessedClientInitialMetadata()502 ClientMetadata& UnprocessedClientInitialMetadata() {
503 return spine_->UnprocessedClientInitialMetadata();
504 }
505
AddCallStack(RefCountedPtr<CallFilters::Stack> call_filters)506 void AddCallStack(RefCountedPtr<CallFilters::Stack> call_filters) {
507 spine_->call_filters().AddStack(std::move(call_filters));
508 }
509
StartCall()510 CallHandler StartCall() {
511 spine_->call_filters().Start();
512 return CallHandler(std::move(spine_));
513 }
514
arena()515 Arena* arena() { return spine_->arena(); }
516
517 private:
518 RefCountedPtr<CallSpine> spine_;
519 };
520
521 struct CallInitiatorAndHandler {
522 CallInitiator initiator;
523 UnstartedCallHandler handler;
524 };
525
526 CallInitiatorAndHandler MakeCallPair(
527 ClientMetadataHandle client_initial_metadata, RefCountedPtr<Arena> arena);
528
// Adapts a call half (anything exposing PullMessage()) into an object with a
// Next() method. The call half is taken by value and owned by the returned
// adaptor; each Next() call forwards to PullMessage() on that owned copy.
template <typename CallHalf>
auto MessagesFrom(CallHalf h) {
  struct OwningSource {
    CallHalf h;
    auto Next() { return h.PullMessage(); }
  };
  OwningSource source{std::move(h)};
  return source;
}
537
// Pointer flavour of MessagesFrom: the returned adaptor only stores `h` and
// does not own the call half, so the pointee must outlive the adaptor. Each
// Next() call forwards to PullMessage() through the stored pointer.
template <typename CallHalf>
auto MessagesFrom(CallHalf* h) {
  struct BorrowingSource {
    CallHalf* h;
    auto Next() { return h->PullMessage(); }
  };
  BorrowingSource source{h};
  return source;
}
546
547 // Forward a call from `call_handler` to `call_initiator` (with initial metadata
548 // `client_initial_metadata`)
549 // `on_server_trailing_metadata_from_initiator` is a callback that will be
550 // called with the server trailing metadata received by the initiator, and can
551 // be used to mutate that metadata if desired.
552 void ForwardCall(
553 CallHandler call_handler, CallInitiator call_initiator,
554 absl::AnyInvocable<void(ServerMetadata&)>
555 on_server_trailing_metadata_from_initiator = [](ServerMetadata&) {});
556
557 } // namespace grpc_core
558
559 #endif // GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
560