• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
16 #define GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include "absl/log/check.h"
21 #include "src/core/lib/promise/detail/status.h"
22 #include "src/core/lib/promise/if.h"
23 #include "src/core/lib/promise/latch.h"
24 #include "src/core/lib/promise/party.h"
25 #include "src/core/lib/promise/pipe.h"
26 #include "src/core/lib/promise/prioritized_race.h"
27 #include "src/core/lib/promise/promise.h"
28 #include "src/core/lib/promise/status_flag.h"
29 #include "src/core/lib/promise/try_seq.h"
30 #include "src/core/lib/transport/call_arena_allocator.h"
31 #include "src/core/lib/transport/call_filters.h"
32 #include "src/core/lib/transport/message.h"
33 #include "src/core/lib/transport/metadata.h"
34 #include "src/core/util/dual_ref_counted.h"
35 
36 namespace grpc_core {
37 
38 // The common middle part of a call - a reference is held by each of
39 // CallInitiator and CallHandler - which provide interfaces that are appropriate
40 // for each side of a call.
41 // Hosts context, call filters, and the arena.
42 class CallSpine final : public Party {
43  public:
Create(ClientMetadataHandle client_initial_metadata,RefCountedPtr<Arena> arena)44   static RefCountedPtr<CallSpine> Create(
45       ClientMetadataHandle client_initial_metadata,
46       RefCountedPtr<Arena> arena) {
47     Arena* arena_ptr = arena.get();
48     return RefCountedPtr<CallSpine>(arena_ptr->New<CallSpine>(
49         std::move(client_initial_metadata), std::move(arena)));
50   }
51 
~CallSpine()52   ~CallSpine() override { CallOnDone(true); }
53 
call_filters()54   CallFilters& call_filters() { return call_filters_; }
55 
56   // Add a callback to be called when server trailing metadata is received and
57   // return true.
58   // If CallOnDone has already been invoked, does nothing and returns false.
OnDone(absl::AnyInvocable<void (bool)> fn)59   GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
60     if (call_filters().WasServerTrailingMetadataPulled()) {
61       return false;
62     }
63     if (on_done_ == nullptr) {
64       on_done_ = std::move(fn);
65       return true;
66     }
67     on_done_ = [first = std::move(fn),
68                 next = std::move(on_done_)](bool cancelled) mutable {
69       first(cancelled);
70       next(cancelled);
71     };
72     return true;
73   }
CallOnDone(bool cancelled)74   void CallOnDone(bool cancelled) {
75     if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(cancelled);
76   }
77 
PullServerInitialMetadata()78   auto PullServerInitialMetadata() {
79     return call_filters().PullServerInitialMetadata();
80   }
81 
PullServerTrailingMetadata()82   auto PullServerTrailingMetadata() {
83     return Map(
84         call_filters().PullServerTrailingMetadata(),
85         [this](ServerMetadataHandle result) {
86           CallOnDone(result->get(GrpcCallWasCancelled()).value_or(false));
87           return result;
88         });
89   }
90 
PushClientToServerMessage(MessageHandle message)91   auto PushClientToServerMessage(MessageHandle message) {
92     return call_filters().PushClientToServerMessage(std::move(message));
93   }
94 
PullClientToServerMessage()95   auto PullClientToServerMessage() {
96     return call_filters().PullClientToServerMessage();
97   }
98 
PushServerToClientMessage(MessageHandle message)99   auto PushServerToClientMessage(MessageHandle message) {
100     return call_filters().PushServerToClientMessage(std::move(message));
101   }
102 
PullServerToClientMessage()103   auto PullServerToClientMessage() {
104     return call_filters().PullServerToClientMessage();
105   }
106 
PushServerTrailingMetadata(ServerMetadataHandle md)107   void PushServerTrailingMetadata(ServerMetadataHandle md) {
108     GRPC_TRACE_LOG(call_state, INFO)
109         << "[call_state] PushServerTrailingMetadata: " << md->DebugString();
110     call_filters().PushServerTrailingMetadata(std::move(md));
111   }
112 
FinishSends()113   void FinishSends() { call_filters().FinishClientToServerSends(); }
114 
PullClientInitialMetadata()115   auto PullClientInitialMetadata() {
116     return call_filters().PullClientInitialMetadata();
117   }
118 
PushServerInitialMetadata(ServerMetadataHandle md)119   StatusFlag PushServerInitialMetadata(ServerMetadataHandle md) {
120     return call_filters().PushServerInitialMetadata(std::move(md));
121   }
122 
WasCancelled()123   auto WasCancelled() { return call_filters().WasCancelled(); }
124 
UnprocessedClientInitialMetadata()125   ClientMetadata& UnprocessedClientInitialMetadata() {
126     return *call_filters().unprocessed_client_initial_metadata();
127   }
128 
129   // Wrap a promise so that if it returns failure it automatically cancels
130   // the rest of the call.
131   // The resulting (returned) promise will resolve to Empty.
132   template <typename Promise>
CancelIfFails(Promise promise)133   auto CancelIfFails(Promise promise) {
134     DCHECK(GetContext<Activity>() == this);
135     using P = promise_detail::PromiseLike<Promise>;
136     using ResultType = typename P::Result;
137     return Map(std::move(promise), [this](ResultType r) {
138       CancelIfFailed(r);
139       return r;
140     });
141   }
142 
143   template <typename StatusType>
CancelIfFailed(const StatusType & r)144   void CancelIfFailed(const StatusType& r) {
145     if (!IsStatusOk(r)) {
146       Cancel();
147     }
148   }
149 
Cancel()150   void Cancel() { call_filters().Cancel(); }
151 
152   // Spawn a promise that returns Empty{} and save some boilerplate handling
153   // that detail.
154   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)155   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
156     Spawn(name, std::move(promise_factory), [](Empty) {});
157   }
158 
159   // Spawn a promise that returns some status-like type; if the status
160   // represents failure automatically cancel the rest of the call.
161   template <typename PromiseFactory>
162   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
163                     DebugLocation whence = {}) {
164     using FactoryType =
165         promise_detail::OncePromiseFactory<void, PromiseFactory>;
166     using PromiseType = typename FactoryType::Promise;
167     using ResultType = typename PromiseType::Result;
168     static_assert(
169         std::is_same<bool,
170                      decltype(IsStatusOk(std::declval<ResultType>()))>::value,
171         "SpawnGuarded promise must return a status-like object");
172     Spawn(name, std::move(promise_factory), [this, whence](ResultType r) {
173       if (!IsStatusOk(r)) {
174         GRPC_TRACE_LOG(promise_primitives, INFO)
175             << "SpawnGuarded sees failure: " << r
176             << " (source: " << whence.file() << ":" << whence.line() << ")";
177         auto status = StatusCast<ServerMetadataHandle>(std::move(r));
178         status->Set(GrpcCallWasCancelled(), true);
179         PushServerTrailingMetadata(std::move(status));
180       }
181     });
182   }
183 
184   // Wrap a promise so that if the call completes that promise is cancelled.
185   template <typename Promise>
UntilCallCompletes(Promise promise)186   auto UntilCallCompletes(Promise promise) {
187     using Result = PromiseResult<Promise>;
188     return PrioritizedRace(std::move(promise), Map(WasCancelled(), [](bool) {
189                              return FailureStatusCast<Result>(Failure{});
190                            }));
191   }
192 
193   template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)194   void SpawnGuardedUntilCallCompletes(absl::string_view name,
195                                       PromiseFactory promise_factory) {
196     SpawnGuarded(name, [this, promise_factory]() mutable {
197       return UntilCallCompletes(promise_factory());
198     });
199   }
200 
201   // Spawned operations: these are callable from /outside/ the call; they spawn
202   // an operation into the call and execute that operation.
203 
SpawnPushServerInitialMetadata(ServerMetadataHandle md)204   void SpawnPushServerInitialMetadata(ServerMetadataHandle md) {
205     SpawnInfallible(
206         "push-server-initial-metadata",
207         [md = std::move(md), self = RefAsSubclass<CallSpine>()]() mutable {
208           self->CancelIfFailed(self->PushServerInitialMetadata(std::move(md)));
209         });
210   }
211 
SpawnPushServerToClientMessage(MessageHandle msg)212   auto SpawnPushServerToClientMessage(MessageHandle msg) {
213     return SpawnWaitable(
214         "push-message",
215         [msg = std::move(msg), self = RefAsSubclass<CallSpine>()]() mutable {
216           return self->CancelIfFails(
217               self->PushServerToClientMessage(std::move(msg)));
218         });
219   }
220 
SpawnPushClientToServerMessage(MessageHandle msg)221   auto SpawnPushClientToServerMessage(MessageHandle msg) {
222     return SpawnWaitable(
223         "push-message",
224         [msg = std::move(msg), self = RefAsSubclass<CallSpine>()]() mutable {
225           return self->CancelIfFails(
226               self->PushClientToServerMessage(std::move(msg)));
227         });
228   }
229 
SpawnFinishSends()230   void SpawnFinishSends() {
231     SpawnInfallible("finish-sends", [self = RefAsSubclass<CallSpine>()]() {
232       self->FinishSends();
233       return Empty{};
234     });
235   }
236 
SpawnPushServerTrailingMetadata(ServerMetadataHandle md)237   void SpawnPushServerTrailingMetadata(ServerMetadataHandle md) {
238     SpawnInfallible(
239         "push-server-trailing-metadata",
240         [md = std::move(md), self = RefAsSubclass<CallSpine>()]() mutable {
241           self->PushServerTrailingMetadata(std::move(md));
242           return Empty{};
243         });
244   }
245 
SpawnCancel()246   void SpawnCancel() {
247     SpawnInfallible("cancel", [self = RefAsSubclass<CallSpine>()]() {
248       self->call_filters().Cancel();
249     });
250   }
251 
AddChildCall(RefCountedPtr<CallSpine> child_call)252   void AddChildCall(RefCountedPtr<CallSpine> child_call) {
253     child_calls_.push_back(std::move(child_call));
254     if (child_calls_.size() == 1) {
255       SpawnInfallible(
256           "check_cancellation", [self = RefAsSubclass<CallSpine>()]() mutable {
257             auto was_completed =
258                 self->call_filters().ServerTrailingMetadataWasPushed();
259             return Map(std::move(was_completed),
260                        [self = std::move(self)](Empty) {
261                          for (auto& child : self->child_calls_) {
262                            child->SpawnCancel();
263                          }
264                          return Empty{};
265                        });
266           });
267     }
268   }
269 
270  private:
271   friend class Arena;
CallSpine(ClientMetadataHandle client_initial_metadata,RefCountedPtr<Arena> arena)272   CallSpine(ClientMetadataHandle client_initial_metadata,
273             RefCountedPtr<Arena> arena)
274       : Party(std::move(arena)),
275         call_filters_(std::move(client_initial_metadata)) {}
276 
277   // Call filters/pipes part of the spine
278   CallFilters call_filters_;
279   absl::AnyInvocable<void(bool)> on_done_{nullptr};
280   // Call spines that should be cancelled if this spine is cancelled
281   absl::InlinedVector<RefCountedPtr<CallSpine>, 3> child_calls_;
282 };
283 
284 class CallHandler;
285 
286 class CallInitiator {
287  public:
288   using NextMessage = ServerToClientNextMessage;
289 
290   CallInitiator() = default;
CallInitiator(RefCountedPtr<CallSpine> spine)291   explicit CallInitiator(RefCountedPtr<CallSpine> spine)
292       : spine_(std::move(spine)) {}
293 
294   template <typename Promise>
CancelIfFails(Promise promise)295   auto CancelIfFails(Promise promise) {
296     return spine_->CancelIfFails(std::move(promise));
297   }
298 
PullServerInitialMetadata()299   auto PullServerInitialMetadata() {
300     return spine_->PullServerInitialMetadata();
301   }
302 
PushMessage(MessageHandle message)303   auto PushMessage(MessageHandle message) {
304     return spine_->PushClientToServerMessage(std::move(message));
305   }
306 
SpawnPushMessage(MessageHandle message)307   auto SpawnPushMessage(MessageHandle message) {
308     return spine_->SpawnPushClientToServerMessage(std::move(message));
309   }
310 
FinishSends()311   void FinishSends() { spine_->FinishSends(); }
312 
SpawnFinishSends()313   void SpawnFinishSends() { spine_->SpawnFinishSends(); }
314 
PullMessage()315   auto PullMessage() { return spine_->PullServerToClientMessage(); }
316 
PullServerTrailingMetadata()317   auto PullServerTrailingMetadata() {
318     return spine_->PullServerTrailingMetadata();
319   }
320 
Cancel(absl::Status error)321   void Cancel(absl::Status error) {
322     CHECK(!error.ok());
323     auto status = ServerMetadataFromStatus(error);
324     status->Set(GrpcCallWasCancelled(), true);
325     spine_->PushServerTrailingMetadata(std::move(status));
326   }
327 
SpawnCancel(absl::Status error)328   void SpawnCancel(absl::Status error) {
329     CHECK(!error.ok());
330     auto status = ServerMetadataFromStatus(error);
331     status->Set(GrpcCallWasCancelled(), true);
332     spine_->SpawnPushServerTrailingMetadata(std::move(status));
333   }
334 
Cancel()335   void Cancel() { spine_->Cancel(); }
336 
SpawnCancel()337   void SpawnCancel() { spine_->SpawnCancel(); }
338 
OnDone(absl::AnyInvocable<void (bool)> fn)339   GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
340     return spine_->OnDone(std::move(fn));
341   }
342 
343   template <typename PromiseFactory>
SpawnGuarded(absl::string_view name,PromiseFactory promise_factory)344   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
345     spine_->SpawnGuarded(name, std::move(promise_factory));
346   }
347 
348   template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)349   void SpawnGuardedUntilCallCompletes(absl::string_view name,
350                                       PromiseFactory promise_factory) {
351     spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
352   }
353 
354   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)355   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
356     spine_->SpawnInfallible(name, std::move(promise_factory));
357   }
358 
359   template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)360   auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
361     return spine_->SpawnWaitable(name, std::move(promise_factory));
362   }
363 
WasCancelledPushed()364   bool WasCancelledPushed() const {
365     return spine_->call_filters().WasCancelledPushed();
366   }
367 
arena()368   Arena* arena() { return spine_->arena(); }
party()369   Party* party() { return spine_.get(); }
370 
371  private:
372   friend class CallHandler;
373   RefCountedPtr<CallSpine> spine_;
374 };
375 
376 class CallHandler {
377  public:
378   using NextMessage = ClientToServerNextMessage;
379 
CallHandler(RefCountedPtr<CallSpine> spine)380   explicit CallHandler(RefCountedPtr<CallSpine> spine)
381       : spine_(std::move(spine)) {}
382 
PullClientInitialMetadata()383   auto PullClientInitialMetadata() {
384     return spine_->PullClientInitialMetadata();
385   }
386 
PushServerInitialMetadata(ServerMetadataHandle md)387   auto PushServerInitialMetadata(ServerMetadataHandle md) {
388     return spine_->PushServerInitialMetadata(std::move(md));
389   }
390 
SpawnPushServerInitialMetadata(ServerMetadataHandle md)391   void SpawnPushServerInitialMetadata(ServerMetadataHandle md) {
392     return spine_->SpawnPushServerInitialMetadata(std::move(md));
393   }
394 
PushServerTrailingMetadata(ServerMetadataHandle status)395   void PushServerTrailingMetadata(ServerMetadataHandle status) {
396     spine_->PushServerTrailingMetadata(std::move(status));
397   }
398 
SpawnPushServerTrailingMetadata(ServerMetadataHandle status)399   void SpawnPushServerTrailingMetadata(ServerMetadataHandle status) {
400     spine_->SpawnPushServerTrailingMetadata(std::move(status));
401   }
402 
OnDone(absl::AnyInvocable<void (bool)> fn)403   GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
404     return spine_->OnDone(std::move(fn));
405   }
406 
407   template <typename Promise>
CancelIfFails(Promise promise)408   auto CancelIfFails(Promise promise) {
409     return spine_->CancelIfFails(std::move(promise));
410   }
411 
PushMessage(MessageHandle message)412   auto PushMessage(MessageHandle message) {
413     return spine_->PushServerToClientMessage(std::move(message));
414   }
415 
SpawnPushMessage(MessageHandle message)416   auto SpawnPushMessage(MessageHandle message) {
417     return spine_->SpawnPushServerToClientMessage(std::move(message));
418   }
419 
PullMessage()420   auto PullMessage() { return spine_->PullClientToServerMessage(); }
421 
WasCancelled()422   auto WasCancelled() { return spine_->WasCancelled(); }
423 
WasCancelledPushed()424   bool WasCancelledPushed() const {
425     return spine_->call_filters().WasCancelledPushed();
426   }
427 
428   template <typename PromiseFactory>
429   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
430                     DebugLocation whence = {}) {
431     spine_->SpawnGuarded(name, std::move(promise_factory), whence);
432   }
433 
434   template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)435   void SpawnGuardedUntilCallCompletes(absl::string_view name,
436                                       PromiseFactory promise_factory) {
437     spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
438   }
439 
440   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)441   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
442     spine_->SpawnInfallible(name, std::move(promise_factory));
443   }
444 
445   template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)446   auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
447     return spine_->SpawnWaitable(name, std::move(promise_factory));
448   }
449 
AddChildCall(const CallInitiator & initiator)450   void AddChildCall(const CallInitiator& initiator) {
451     CHECK(initiator.spine_ != nullptr);
452     spine_->AddChildCall(initiator.spine_);
453   }
454 
arena()455   Arena* arena() { return spine_->arena(); }
party()456   Party* party() { return spine_.get(); }
457 
458  private:
459   RefCountedPtr<CallSpine> spine_;
460 };
461 
462 class UnstartedCallHandler {
463  public:
UnstartedCallHandler(RefCountedPtr<CallSpine> spine)464   explicit UnstartedCallHandler(RefCountedPtr<CallSpine> spine)
465       : spine_(std::move(spine)) {}
466 
PushServerTrailingMetadata(ServerMetadataHandle status)467   void PushServerTrailingMetadata(ServerMetadataHandle status) {
468     spine_->PushServerTrailingMetadata(std::move(status));
469   }
470 
OnDone(absl::AnyInvocable<void (bool)> fn)471   GRPC_MUST_USE_RESULT bool OnDone(absl::AnyInvocable<void(bool)> fn) {
472     return spine_->OnDone(std::move(fn));
473   }
474 
475   template <typename Promise>
CancelIfFails(Promise promise)476   auto CancelIfFails(Promise promise) {
477     return spine_->CancelIfFails(std::move(promise));
478   }
479 
480   template <typename PromiseFactory>
481   void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory,
482                     DebugLocation whence = {}) {
483     spine_->SpawnGuarded(name, std::move(promise_factory), whence);
484   }
485 
486   template <typename PromiseFactory>
SpawnGuardedUntilCallCompletes(absl::string_view name,PromiseFactory promise_factory)487   void SpawnGuardedUntilCallCompletes(absl::string_view name,
488                                       PromiseFactory promise_factory) {
489     spine_->SpawnGuardedUntilCallCompletes(name, std::move(promise_factory));
490   }
491 
492   template <typename PromiseFactory>
SpawnInfallible(absl::string_view name,PromiseFactory promise_factory)493   void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
494     spine_->SpawnInfallible(name, std::move(promise_factory));
495   }
496 
497   template <typename PromiseFactory>
SpawnWaitable(absl::string_view name,PromiseFactory promise_factory)498   auto SpawnWaitable(absl::string_view name, PromiseFactory promise_factory) {
499     return spine_->SpawnWaitable(name, std::move(promise_factory));
500   }
501 
UnprocessedClientInitialMetadata()502   ClientMetadata& UnprocessedClientInitialMetadata() {
503     return spine_->UnprocessedClientInitialMetadata();
504   }
505 
AddCallStack(RefCountedPtr<CallFilters::Stack> call_filters)506   void AddCallStack(RefCountedPtr<CallFilters::Stack> call_filters) {
507     spine_->call_filters().AddStack(std::move(call_filters));
508   }
509 
StartCall()510   CallHandler StartCall() {
511     spine_->call_filters().Start();
512     return CallHandler(std::move(spine_));
513   }
514 
arena()515   Arena* arena() { return spine_->arena(); }
516 
517  private:
518   RefCountedPtr<CallSpine> spine_;
519 };
520 
521 struct CallInitiatorAndHandler {
522   CallInitiator initiator;
523   UnstartedCallHandler handler;
524 };
525 
526 CallInitiatorAndHandler MakeCallPair(
527     ClientMetadataHandle client_initial_metadata, RefCountedPtr<Arena> arena);
528 
// Adapts a call half (anything with a PullMessage() member) into an object
// exposing Next(). This overload takes the call half by value and stores it
// inside the returned object.
template <typename CallHalf>
auto MessagesFrom(CallHalf h) {
  struct OwningMessageSource {
    CallHalf h;
    // Each call forwards to PullMessage() on the stored call half.
    auto Next() { return h.PullMessage(); }
  };
  return OwningMessageSource{std::move(h)};
}
537 
// Adapts a call half (anything with a PullMessage() member) into an object
// exposing Next(). This overload only stores the pointer: the pointee is not
// copied and must stay alive for as long as the returned object is used.
template <typename CallHalf>
auto MessagesFrom(CallHalf* h) {
  struct BorrowingMessageSource {
    CallHalf* h;
    // Each call forwards to PullMessage() on the pointed-to call half.
    auto Next() { return h->PullMessage(); }
  };
  return BorrowingMessageSource{h};
}
546 
547 // Forward a call from `call_handler` to `call_initiator` (with initial metadata
548 // `client_initial_metadata`)
549 // `on_server_trailing_metadata_from_initiator` is a callback that will be
550 // called with the server trailing metadata received by the initiator, and can
551 // be used to mutate that metadata if desired.
552 void ForwardCall(
553     CallHandler call_handler, CallInitiator call_initiator,
554     absl::AnyInvocable<void(ServerMetadata&)>
555         on_server_trailing_metadata_from_initiator = [](ServerMetadata&) {});
556 
557 }  // namespace grpc_core
558 
559 #endif  // GRPC_SRC_CORE_LIB_TRANSPORT_CALL_SPINE_H
560