1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/lib/channel/connected_channel.h"
22 
23 #include <inttypes.h>
24 
25 #include <functional>
26 #include <memory>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30 
31 #include "absl/status/status.h"
32 #include "absl/status/statusor.h"
33 #include "absl/types/optional.h"
34 
35 #include <grpc/grpc.h>
36 #include <grpc/status.h>
37 #include <grpc/support/alloc.h>
38 #include <grpc/support/log.h>
39 
40 #include "src/core/lib/channel/call_finalization.h"
41 #include "src/core/lib/channel/channel_args.h"
42 #include "src/core/lib/channel/channel_fwd.h"
43 #include "src/core/lib/channel/channel_stack.h"
44 #include "src/core/lib/config/core_configuration.h"
45 #include "src/core/lib/debug/trace.h"
46 #include "src/core/lib/experiments/experiments.h"
47 #include "src/core/lib/gpr/alloc.h"
48 #include "src/core/lib/gprpp/debug_location.h"
49 #include "src/core/lib/gprpp/orphanable.h"
50 #include "src/core/lib/gprpp/ref_counted_ptr.h"
51 #include "src/core/lib/gprpp/time.h"
52 #include "src/core/lib/iomgr/call_combiner.h"
53 #include "src/core/lib/iomgr/closure.h"
54 #include "src/core/lib/iomgr/error.h"
55 #include "src/core/lib/iomgr/polling_entity.h"
56 #include "src/core/lib/promise/activity.h"
57 #include "src/core/lib/promise/arena_promise.h"
58 #include "src/core/lib/promise/context.h"
59 #include "src/core/lib/promise/detail/status.h"
60 #include "src/core/lib/promise/for_each.h"
61 #include "src/core/lib/promise/if.h"
62 #include "src/core/lib/promise/latch.h"
63 #include "src/core/lib/promise/loop.h"
64 #include "src/core/lib/promise/map.h"
65 #include "src/core/lib/promise/party.h"
66 #include "src/core/lib/promise/pipe.h"
67 #include "src/core/lib/promise/poll.h"
68 #include "src/core/lib/promise/promise.h"
69 #include "src/core/lib/promise/race.h"
70 #include "src/core/lib/promise/seq.h"
71 #include "src/core/lib/promise/try_seq.h"
72 #include "src/core/lib/resource_quota/arena.h"
73 #include "src/core/lib/slice/slice.h"
74 #include "src/core/lib/slice/slice_buffer.h"
75 #include "src/core/lib/surface/call.h"
76 #include "src/core/lib/surface/call_trace.h"
77 #include "src/core/lib/surface/channel_stack_type.h"
78 #include "src/core/lib/transport/batch_builder.h"
79 #include "src/core/lib/transport/error_utils.h"
80 #include "src/core/lib/transport/metadata_batch.h"
81 #include "src/core/lib/transport/transport.h"
82 
83 typedef struct connected_channel_channel_data {
84   grpc_core::Transport* transport;
85 } channel_data;
86 
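   // Bundles a closure with the information needed to re-run the original
   // callback under the call combiner when the transport invokes it.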
87 struct callback_state {
88   grpc_closure closure;
89   grpc_closure* original_closure;
90   grpc_core::CallCombiner* call_combiner;
91   const char* reason;
92 };
93 typedef struct connected_channel_call_data {
94   grpc_core::CallCombiner* call_combiner;
95   // Closures used for returning results on the call combiner.
96   callback_state on_complete[6];  // Max number of pending batches.
97   callback_state recv_initial_metadata_ready;
98   callback_state recv_message_ready;
99   callback_state recv_trailing_metadata_ready;
100 } call_data;
101 
102 static void run_in_call_combiner(void* arg, grpc_error_handle error) {
103   callback_state* state = static_cast<callback_state*>(arg);
104   GRPC_CALL_COMBINER_START(state->call_combiner, state->original_closure, error,
105                            state->reason);
106 }
107 
108 static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) {
109   run_in_call_combiner(arg, error);
110   gpr_free(arg);
111 }
112 
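    // Replaces *original_closure with a wrapper that re-schedules the original
    // closure on the call combiner; if free_when_done is set, the wrapper also
    // frees `state` (used for the heap-allocated cancel_stream path).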
113 static void intercept_callback(call_data* calld, callback_state* state,
114                                bool free_when_done, const char* reason,
115                                grpc_closure** original_closure) {
116   state->original_closure = *original_closure;
117   state->call_combiner = calld->call_combiner;
118   state->reason = reason;
119   *original_closure = GRPC_CLOSURE_INIT(
120       &state->closure,
121       free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner,
122       state, grpc_schedule_on_exec_ctx);
123 }
124 
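    // Returns the dedicated on_complete slot for the op carried by this batch;
    // giving each op type its own slot lets one batch of each kind be pending
    // at the same time.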
125 static callback_state* get_state_for_batch(
126     call_data* calld, grpc_transport_stream_op_batch* batch) {
127   if (batch->send_initial_metadata) return &calld->on_complete[0];
128   if (batch->send_message) return &calld->on_complete[1];
129   if (batch->send_trailing_metadata) return &calld->on_complete[2];
130   if (batch->recv_initial_metadata) return &calld->on_complete[3];
131   if (batch->recv_message) return &calld->on_complete[4];
132   if (batch->recv_trailing_metadata) return &calld->on_complete[5];
133   GPR_UNREACHABLE_CODE(return nullptr);
134 }
135 
136 // We perform a small hack to locate transport data alongside the connected
137 // channel data in call allocations, to allow everything to be pulled in minimal
138 // cache line requests
139 #define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \
140   ((grpc_stream*)(((char*)(calld)) +           \
141                   GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
142 #define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \
143   ((call_data*)(((char*)(transport_stream)) -             \
144                 GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
145 
146 // Intercept a call operation and either push it directly up or translate it
147 // into transport stream operations
148 static void connected_channel_start_transport_stream_op_batch(
149     grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
150   call_data* calld = static_cast<call_data*>(elem->call_data);
151   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
152   if (batch->recv_initial_metadata) {
153     callback_state* state = &calld->recv_initial_metadata_ready;
154     intercept_callback(
155         calld, state, false, "recv_initial_metadata_ready",
156         &batch->payload->recv_initial_metadata.recv_initial_metadata_ready);
157   }
158   if (batch->recv_message) {
159     callback_state* state = &calld->recv_message_ready;
160     intercept_callback(calld, state, false, "recv_message_ready",
161                        &batch->payload->recv_message.recv_message_ready);
162   }
163   if (batch->recv_trailing_metadata) {
164     callback_state* state = &calld->recv_trailing_metadata_ready;
165     intercept_callback(
166         calld, state, false, "recv_trailing_metadata_ready",
167         &batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready);
168   }
169   if (batch->cancel_stream) {
170     // There can be more than one cancellation batch in flight at any
171     // given time, so we can't just pick out a fixed index into
172     // calld->on_complete like we can for the other ops.  However,
173     // cancellation isn't in the fast path, so we just allocate a new
174     // closure for each one.
175     callback_state* state =
176         static_cast<callback_state*>(gpr_malloc(sizeof(*state)));
177     intercept_callback(calld, state, true, "on_complete (cancel_stream)",
178                        &batch->on_complete);
179   } else if (batch->on_complete != nullptr) {
180     callback_state* state = get_state_for_batch(calld, batch);
181     intercept_callback(calld, state, false, "on_complete", &batch->on_complete);
182   }
183   chand->transport->filter_stack_transport()->PerformStreamOp(
184       TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch);
185   GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport");
186 }
187 
188 static void connected_channel_start_transport_op(grpc_channel_element* elem,
189                                                  grpc_transport_op* op) {
190   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
191   chand->transport->PerformOp(op);
192 }
193 
194 // Constructor for call_data
195 static grpc_error_handle connected_channel_init_call_elem(
196     grpc_call_element* elem, const grpc_call_element_args* args) {
197   call_data* calld = static_cast<call_data*>(elem->call_data);
198   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
199   calld->call_combiner = args->call_combiner;
200   chand->transport->filter_stack_transport()->InitStream(
201       TRANSPORT_STREAM_FROM_CALL_DATA(calld), &args->call_stack->refcount,
202       args->server_transport_data, args->arena);
203   return absl::OkStatus();
204 }
205 
206 static void set_pollset_or_pollset_set(grpc_call_element* elem,
207                                        grpc_polling_entity* pollent) {
208   call_data* calld = static_cast<call_data*>(elem->call_data);
209   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
210   chand->transport->SetPollingEntity(TRANSPORT_STREAM_FROM_CALL_DATA(calld),
211                                      pollent);
212 }
213 
214 // Destructor for call_data
215 static void connected_channel_destroy_call_elem(
216     grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
217     grpc_closure* then_schedule_closure) {
218   call_data* calld = static_cast<call_data*>(elem->call_data);
219   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
220   chand->transport->filter_stack_transport()->DestroyStream(
221       TRANSPORT_STREAM_FROM_CALL_DATA(calld), then_schedule_closure);
222 }
223 
224 // Constructor for channel_data
225 static grpc_error_handle connected_channel_init_channel_elem(
226     grpc_channel_element* elem, grpc_channel_element_args* args) {
227   channel_data* cd = static_cast<channel_data*>(elem->channel_data);
228   GPR_ASSERT(args->is_last);
229   cd->transport = args->channel_args.GetObject<grpc_core::Transport>();
230   return absl::OkStatus();
231 }
232 
233 // Destructor for channel_data
234 static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) {
235   channel_data* cd = static_cast<channel_data*>(elem->channel_data);
236   if (cd->transport) {
237     cd->transport->Orphan();
238   }
239 }
240 
241 // No-op.
242 static void connected_channel_get_channel_info(
243     grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) {
244 }
245 
246 namespace grpc_core {
247 namespace {
248 
249 #if defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || \
250     defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
251 class ConnectedChannelStream : public Orphanable {
252  public:
253   explicit ConnectedChannelStream(Transport* transport)
254       : transport_(transport), stream_(nullptr, StreamDeleter(this)) {
255     GRPC_STREAM_REF_INIT(
256         &stream_refcount_, 1,
257         [](void* p, grpc_error_handle) {
258           static_cast<ConnectedChannelStream*>(p)->BeginDestroy();
259         },
260         this, "ConnectedChannelStream");
261   }
262 
263   Transport* transport() { return transport_; }
264   grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; }
265 
266   BatchBuilder::Target batch_target() {
267     return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_};
268   }
269 
270   void IncrementRefCount(const char* reason = "smartptr") {
271 #ifndef NDEBUG
272     grpc_stream_ref(&stream_refcount_, reason);
273 #else
274     (void)reason;
275     grpc_stream_ref(&stream_refcount_);
276 #endif
277   }
278 
279   void Unref(const char* reason = "smartptr") {
280 #ifndef NDEBUG
281     grpc_stream_unref(&stream_refcount_, reason);
282 #else
283     (void)reason;
284     grpc_stream_unref(&stream_refcount_);
285 #endif
286   }
287 
288   RefCountedPtr<ConnectedChannelStream> InternalRef() {
289     IncrementRefCount("smartptr");
290     return RefCountedPtr<ConnectedChannelStream>(this);
291   }
292 
293   void Orphan() final {
294     bool finished = finished_.IsSet();
295     if (grpc_call_trace.enabled()) {
296       gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d",
297               party_->DebugTag().c_str(), finished);
298     }
299     // If we hadn't already observed the stream to be finished, we need to
300     // cancel it at the transport.
301     if (!finished) {
302       party_->Spawn(
303           "finish",
304           [self = InternalRef()]() {
305             if (!self->finished_.IsSet()) {
306               self->finished_.Set();
307             }
308             return Empty{};
309           },
310           [](Empty) {});
311       GetContext<BatchBuilder>()->Cancel(batch_target(),
312                                          absl::CancelledError());
313     }
314     Unref("orphan connected stream");
315   }
316 
317   // Returns a promise that implements the receive message loop.
318   auto RecvMessages(PipeSender<MessageHandle>* incoming_messages,
319                     bool cancel_on_error);
320   // Returns a promise that implements the send message loop.
321   auto SendMessages(PipeReceiver<MessageHandle>* outgoing_messages);
322 
323   void SetStream(grpc_stream* stream) { stream_.reset(stream); }
324   grpc_stream* stream() { return stream_.get(); }
325   grpc_stream_refcount* stream_refcount() { return &stream_refcount_; }
326 
327   void set_finished() { finished_.Set(); }
328   auto WaitFinished() { return finished_.Wait(); }
329 
330  private:
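      // Deleter for StreamPtr: asks the transport to destroy the stream and
      // signals stream_destroyed_ once that completes.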
331   class StreamDeleter {
332    public:
333     explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {}
334     void operator()(grpc_stream* stream) const {
335       if (stream == nullptr) return;
336       impl_->transport()->filter_stack_transport()->DestroyStream(
337           stream, impl_->stream_destroyed_closure());
338     }
339 
340    private:
341     ConnectedChannelStream* impl_;
342   };
343   using StreamPtr = std::unique_ptr<grpc_stream, StreamDeleter>;
344 
345   void StreamDestroyed() {
346     call_context_->RunInContext([this] { this->~ConnectedChannelStream(); });
347   }
348 
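      // Run when the stream refcount reaches zero: dropping stream_ invokes
      // StreamDeleter (which leads to StreamDestroyed); if no stream was ever
      // set, destroy immediately.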
349   void BeginDestroy() {
350     if (stream_ != nullptr) {
351       stream_.reset();
352     } else {
353       StreamDestroyed();
354     }
355   }
356 
357   Transport* const transport_;
358   RefCountedPtr<CallContext> const call_context_{
359       GetContext<CallContext>()->Ref()};
360   grpc_closure stream_destroyed_ =
361       MakeMemberClosure<ConnectedChannelStream,
362                         &ConnectedChannelStream::StreamDestroyed>(
363           this, DEBUG_LOCATION);
364   grpc_stream_refcount stream_refcount_;
365   StreamPtr stream_;
366   Arena* arena_ = GetContext<Arena>();
367   Party* const party_ = GetContext<Party>();
368   ExternallyObservableLatch<void> finished_;
369 };
370 
371 auto ConnectedChannelStream::RecvMessages(
372     PipeSender<MessageHandle>* incoming_messages, bool cancel_on_error) {
373   return Loop([self = InternalRef(), cancel_on_error,
374                incoming_messages = std::move(*incoming_messages)]() mutable {
375     return Seq(
376         GetContext<BatchBuilder>()->ReceiveMessage(self->batch_target()),
377         [cancel_on_error, &incoming_messages](
378             absl::StatusOr<absl::optional<MessageHandle>> status) mutable {
379           bool has_message = status.ok() && status->has_value();
380           auto publish_message = [&incoming_messages, &status]() {
381             auto pending_message = std::move(**status);
382             if (grpc_call_trace.enabled()) {
383               gpr_log(GPR_INFO,
384                       "%s[connected] RecvMessage: received payload of %" PRIdPTR
385                       " bytes",
386                       GetContext<Activity>()->DebugTag().c_str(),
387                       pending_message->payload()->Length());
388             }
389             return Map(incoming_messages.Push(std::move(pending_message)),
390                        [](bool ok) -> LoopCtl<absl::Status> {
391                          if (!ok) {
392                            if (grpc_call_trace.enabled()) {
393                              gpr_log(
394                                  GPR_INFO,
395                                  "%s[connected] RecvMessage: failed to "
396                                  "push message towards the application",
397                                  GetContext<Activity>()->DebugTag().c_str());
398                            }
399                            return absl::OkStatus();
400                          }
401                          return Continue{};
402                        });
403           };
404           auto publish_close = [cancel_on_error, &incoming_messages,
405                                 &status]() mutable {
406             if (grpc_call_trace.enabled()) {
407               gpr_log(GPR_INFO,
408                       "%s[connected] RecvMessage: reached end of stream with "
409                       "status:%s",
410                       GetContext<Activity>()->DebugTag().c_str(),
411                       status.status().ToString().c_str());
412             }
413             if (cancel_on_error && !status.ok()) {
414               incoming_messages.CloseWithError();
415             } else {
416               incoming_messages.Close();
417             }
418             return Immediate(LoopCtl<absl::Status>(status.status()));
419           };
420           return If(has_message, std::move(publish_message),
421                     std::move(publish_close));
422         });
423   });
424 }
425 
426 auto ConnectedChannelStream::SendMessages(
427     PipeReceiver<MessageHandle>* outgoing_messages) {
428   return ForEach(std::move(*outgoing_messages),
429                  [self = InternalRef()](MessageHandle message) {
430                    return GetContext<BatchBuilder>()->SendMessage(
431                        self->batch_target(), std::move(message));
432                  });
433 }
434 #endif  // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) ||
435         // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
436 
437 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
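    // Emulates a promise-based client call on top of a batch (filter stack)
    // transport: allocates the transport stream in the call arena and spawns
    // party members that translate between pipes and transport batches.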
438 ArenaPromise<ServerMetadataHandle> MakeClientCallPromise(Transport* transport,
439                                                          CallArgs call_args,
440                                                          NextPromiseFactory) {
441   OrphanablePtr<ConnectedChannelStream> stream(
442       GetContext<Arena>()->New<ConnectedChannelStream>(transport));
443   stream->SetStream(static_cast<grpc_stream*>(GetContext<Arena>()->Alloc(
444       transport->filter_stack_transport()->SizeOfStream())));
445   transport->filter_stack_transport()->InitStream(stream->stream(),
446                                                   stream->stream_refcount(),
447                                                   nullptr, GetContext<Arena>());
448   auto* party = GetContext<Party>();
449   party->Spawn("set_polling_entity", call_args.polling_entity->Wait(),
450                [transport, stream = stream->InternalRef()](
451                    grpc_polling_entity polling_entity) {
452                  transport->SetPollingEntity(stream->stream(), &polling_entity);
453                });
454   // Start a loop to send messages from client_to_server_messages to the
455   // transport. When the pipe closes and the loop completes, send a trailing
456   // metadata batch to close the stream.
457   party->Spawn(
458       "send_messages",
459       TrySeq(stream->SendMessages(call_args.client_to_server_messages),
460              [stream = stream->InternalRef()]() {
461                return GetContext<BatchBuilder>()->SendClientTrailingMetadata(
462                    stream->batch_target());
463              }),
464       [](absl::Status) {});
465   // Start a promise to receive server initial metadata and then forward it up
466   // through the receiving pipe.
467   auto server_initial_metadata =
468       GetContext<Arena>()->MakePooled<ServerMetadata>();
469   party->Spawn(
470       "recv_initial_metadata",
471       TrySeq(GetContext<BatchBuilder>()->ReceiveServerInitialMetadata(
472                  stream->batch_target()),
473              [pipe = call_args.server_initial_metadata](
474                  ServerMetadataHandle server_initial_metadata) {
475                if (grpc_call_trace.enabled()) {
476                  gpr_log(GPR_DEBUG,
477                          "%s[connected] Publish server initial metadata: %s",
478                          GetContext<Activity>()->DebugTag().c_str(),
479                          server_initial_metadata->DebugString().c_str());
480                }
481                return Map(pipe->Push(std::move(server_initial_metadata)),
482                           [](bool r) {
483                             if (r) return absl::OkStatus();
484                             return absl::CancelledError();
485                           });
486              }),
487       [](absl::Status) {});
488 
489   // Build up the rest of the main call promise:
490 
491   // Create a promise that will send initial metadata and then signal completion
492   // of that via the token.
493   auto send_initial_metadata = Seq(
494       GetContext<BatchBuilder>()->SendClientInitialMetadata(
495           stream->batch_target(), std::move(call_args.client_initial_metadata)),
496       [sent_initial_metadata_token =
497            std::move(call_args.client_initial_metadata_outstanding)](
498           absl::Status status) mutable {
499         sent_initial_metadata_token.Complete(status.ok());
500         return status;
501       });
502   // Create a promise that will receive server trailing metadata.
503   // If this fails, we massage the error into metadata that we can report
504   // upwards.
505   auto server_trailing_metadata =
506       GetContext<Arena>()->MakePooled<ServerMetadata>();
507   auto recv_trailing_metadata =
508       Map(GetContext<BatchBuilder>()->ReceiveServerTrailingMetadata(
509               stream->batch_target()),
510           [](absl::StatusOr<ServerMetadataHandle> status) mutable {
511             if (!status.ok()) {
512               auto server_trailing_metadata =
513                   GetContext<Arena>()->MakePooled<ServerMetadata>();
514               grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
515               std::string message;
516               grpc_error_get_status(status.status(), Timestamp::InfFuture(),
517                                     &status_code, &message, nullptr, nullptr);
518               server_trailing_metadata->Set(GrpcStatusMetadata(), status_code);
519               server_trailing_metadata->Set(GrpcMessageMetadata(),
520                                             Slice::FromCopiedString(message));
521               return server_trailing_metadata;
522             } else {
523               return std::move(*status);
524             }
525           });
526   // Finally the main call promise.
527   // Concurrently: send initial metadata and receive messages, until BOTH
528   // complete (or one fails).
529   // Next: receive trailing metadata, and return that up the stack.
530   auto recv_messages =
531       stream->RecvMessages(call_args.server_to_client_messages, false);
532   return Map(
533       [send_initial_metadata = std::move(send_initial_metadata),
534        recv_messages = std::move(recv_messages),
535        recv_trailing_metadata = std::move(recv_trailing_metadata),
536        done_send_initial_metadata = false, done_recv_messages = false,
537        done_recv_trailing_metadata =
538            false]() mutable -> Poll<ServerMetadataHandle> {
539         if (!done_send_initial_metadata) {
540           auto p = send_initial_metadata();
541           if (auto* r = p.value_if_ready()) {
542             done_send_initial_metadata = true;
543             if (!r->ok()) return StatusCast<ServerMetadataHandle>(*r);
544           }
545         }
546         if (!done_recv_messages) {
547           auto p = recv_messages();
548           if (p.ready()) {
549             // NOTE: ignore errors here, they'll be collected in the
550             // recv_trailing_metadata.
551             done_recv_messages = true;
552           } else {
553             return Pending{};
554           }
555         }
556         if (!done_recv_trailing_metadata) {
557           auto p = recv_trailing_metadata();
558           if (auto* r = p.value_if_ready()) {
559             done_recv_trailing_metadata = true;
560             return std::move(*r);
561           }
562         }
563         return Pending{};
564       },
565       [stream = std::move(stream)](ServerMetadataHandle result) {
566         stream->set_finished();
567         return result;
568       });
569 }
570 #endif
571 
572 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
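    // Server-side counterpart to MakeClientCallPromise: emulates a
    // promise-based server call on top of a batch (filter stack) transport,
    // receiving client initial metadata, running the rest of the stack via
    // next_promise_factory, and translating pipes to transport batches.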
573 ArenaPromise<ServerMetadataHandle> MakeServerCallPromise(
574     Transport* transport, CallArgs, NextPromiseFactory next_promise_factory) {
575   OrphanablePtr<ConnectedChannelStream> stream(
576       GetContext<Arena>()->New<ConnectedChannelStream>(transport));
577 
578   stream->SetStream(static_cast<grpc_stream*>(GetContext<Arena>()->Alloc(
579       transport->filter_stack_transport()->SizeOfStream())));
580   transport->filter_stack_transport()->InitStream(
581       stream->stream(), stream->stream_refcount(),
582       GetContext<CallContext>()->server_call_context()->server_stream_data(),
583       GetContext<Arena>());
584   auto* party = GetContext<Party>();
585 
586   // Artifacts we need for the lifetime of the call.
587   struct CallData {
588     Pipe<MessageHandle> server_to_client;
589     Pipe<MessageHandle> client_to_server;
590     Pipe<ServerMetadataHandle> server_initial_metadata;
591     Latch<ServerMetadataHandle> failure_latch;
592     Latch<grpc_polling_entity> polling_entity_latch;
593     bool sent_initial_metadata = false;
594     bool sent_trailing_metadata = false;
595   };
596   auto* call_data = GetContext<Arena>()->New<CallData>();
597   GetContext<CallFinalization>()->Add(
598       [call_data](const grpc_call_final_info*) { call_data->~CallData(); });
599 
600   party->Spawn("set_polling_entity", call_data->polling_entity_latch.Wait(),
601                [transport, stream = stream->InternalRef()](
602                    grpc_polling_entity polling_entity) {
603                  transport->SetPollingEntity(stream->stream(), &polling_entity);
604                });
605 
606   auto server_to_client_empty =
607       call_data->server_to_client.receiver.AwaitEmpty();
608 
609   // Create a promise that will receive client initial metadata, and then run
610   // the main stem of the call (calling next_promise_factory up through the
611   // filters).
612   // Race the main call with failure_latch, allowing us to forcefully complete
613   // the call in the case of a failure.
614   auto recv_initial_metadata_then_run_promise =
615       TrySeq(GetContext<BatchBuilder>()->ReceiveClientInitialMetadata(
616                  stream->batch_target()),
617              [next_promise_factory = std::move(next_promise_factory),
618               server_to_client_empty = std::move(server_to_client_empty),
619               call_data](ClientMetadataHandle client_initial_metadata) {
620                auto call_promise = next_promise_factory(CallArgs{
621                    std::move(client_initial_metadata),
622                    ClientInitialMetadataOutstandingToken::Empty(),
623                    &call_data->polling_entity_latch,
624                    &call_data->server_initial_metadata.sender,
625                    &call_data->client_to_server.receiver,
626                    &call_data->server_to_client.sender,
627                });
628                return Race(call_data->failure_latch.Wait(),
629                            [call_promise = std::move(call_promise),
630                             server_to_client_empty =
631                                 std::move(server_to_client_empty)]() mutable
632                            -> Poll<ServerMetadataHandle> {
633                              // TODO(ctiller): this is deeply weird and we need
634                              // to clean this up.
635                              //
636                              // The following few lines check to ensure that
637                              // there's no message currently pending in the
638                              // outgoing message queue, and if (and only if)
639                              // that's true decides to poll the main promise to
640                              // see if there's a result.
641                              //
642                              // This essentially introduces a polling priority
643                              // scheme that makes the current promise structure
644                              // work out the way we want when talking to
645                              // transports.
646                              //
647                              // The problem is that transports are going to need
648                              // to replicate this structure when they convert to
649                              // promises, and that becomes troubling as we'll be
650                              // replicating weirdness throughout the stack.
651                              //
652                              // Instead we likely need to change the way we're
653                              // composing promises through the stack.
654                              //
655                              // Proposed is to change filters from a promise
656                              // that takes ClientInitialMetadata and returns
657                              // ServerTrailingMetadata with three pipes for
658                              // ServerInitialMetadata and
659                              // ClientToServerMessages, ServerToClientMessages.
660                              // Instead we'll have five pipes, moving
661                              // ClientInitialMetadata and ServerTrailingMetadata
662                              // to pipes that can be intercepted.
663                              //
664                              // The effect of this change will be to cripple the
665                              // things that can be done in a filter (but cripple
666                              // in line with what most filters actually do).
667                              // We'll likely need to add a `CallContext::Cancel`
668                              // to allow filters to cancel a request, but this
669                              // would also have the advantage of centralizing
670                              // our cancellation machinery which seems like an
671                              // additional win - with the net effect that the
672                              // shape of the call gets made explicit at the top
673                              // & bottom of the stack.
674                              //
675                              // There's a small set of filters (retry, this one,
676                              // lame client, client channel) that terminate
677                              // stacks and need a richer set of semantics, but
678                              // that ends up being fine because we can spawn
679                              // tasks in parties to handle those edge cases, and
680                              // keep the majority of filters simple: they just
681                              // call InterceptAndMap on a handful of filters at
682                              // call initialization time and then proceed to
683                              // actually filter.
684                              //
685                              // So that's the plan, why isn't it enacted here?
686                              //
687                              // Well, the plan ends up being easy to implement
688                              // in the promise based world (I did a prototype on
689                              // a branch in an afternoon). It's heinous to
690                              // implement in promise_based_filter, and that code
691                              // is load bearing for us at the time of writing.
692                              // It's not worth delaying promises for a further N
693                              // months (N ~ 6) to make that change.
694                              //
695                              // Instead, we'll move forward with this, get
696                              // promise_based_filter out of the picture, and
697                              // then during the mop-up phase for promises tweak
698                              // the compute structure to move to the magical
699                              // five pipes (I'm reminded of an old Onion
700                              // article), and end up in a good happy place.
701                              if (server_to_client_empty().pending()) {
702                                return Pending{};
703                              }
704                              return call_promise();
705                            });
706              });
707 
708   // Promise factory that accepts a ServerMetadataHandle, and sends it as the
709   // trailing metadata for this call.
710   auto send_trailing_metadata = [call_data, stream = stream->InternalRef()](
711                                     ServerMetadataHandle
712                                         server_trailing_metadata) {
713     bool is_cancellation =
714         server_trailing_metadata->get(GrpcCallWasCancelled()).value_or(false);
715     return GetContext<BatchBuilder>()->SendServerTrailingMetadata(
716         stream->batch_target(), std::move(server_trailing_metadata),
717         is_cancellation ||
718             !std::exchange(call_data->sent_initial_metadata, true));
719   };
720 
721   // Runs the receive message loop, either until all the messages
722   // are received or the server call is complete.
723   party->Spawn(
724       "recv_messages",
725       Race(
726           Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
727           Map(stream->RecvMessages(&call_data->client_to_server.sender, true),
728               [failure_latch = &call_data->failure_latch](absl::Status status) {
729                 if (!status.ok() && !failure_latch->is_set()) {
730                   failure_latch->Set(ServerMetadataFromStatus(status));
731                 }
732                 return status;
733               })),
734       [](absl::Status) {});
735 
736   // Run a promise that will send initial metadata (if that pipe sends some).
737   // And then run the send message loop until that completes.
738 
739   auto send_initial_metadata = Seq(
740       Race(Map(stream->WaitFinished(),
741                [](Empty) { return NextResult<ServerMetadataHandle>(true); }),
742            call_data->server_initial_metadata.receiver.Next()),
743       [call_data, stream = stream->InternalRef()](
744           NextResult<ServerMetadataHandle> next_result) mutable {
745         auto md = !call_data->sent_initial_metadata && next_result.has_value()
746                       ? std::move(next_result.value())
747                       : nullptr;
748         if (md != nullptr) {
749           call_data->sent_initial_metadata = true;
750           auto* party = GetContext<Party>();
751           party->Spawn("connected/send_initial_metadata",
752                        GetContext<BatchBuilder>()->SendServerInitialMetadata(
753                            stream->batch_target(), std::move(md)),
754                        [](absl::Status) {});
755           return Immediate(absl::OkStatus());
756         }
757         return Immediate(absl::CancelledError());
758       });
759   party->Spawn(
760       "send_initial_metadata_then_messages",
761       Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
762            TrySeq(std::move(send_initial_metadata),
763                   stream->SendMessages(&call_data->server_to_client.receiver))),
764       [](absl::Status) {});
765 
766   // Spawn a job to fetch the client trailing metadata - if this succeeds,
767   // the client has finished sending; otherwise it signals cancellation from
768   // the client, which we propagate via failure_latch.
769 
770   party->Spawn(
771       "recv_trailing_metadata",
772       Seq(GetContext<BatchBuilder>()->ReceiveClientTrailingMetadata(
773               stream->batch_target()),
774           [failure_latch = &call_data->failure_latch](
775               absl::StatusOr<ClientMetadataHandle> status) mutable {
776             if (grpc_call_trace.enabled()) {
777               gpr_log(
778                   GPR_DEBUG,
779                   "%s[connected] Got trailing metadata; status=%s metadata=%s",
780                   GetContext<Activity>()->DebugTag().c_str(),
781                   status.status().ToString().c_str(),
782                   status.ok() ? (*status)->DebugString().c_str() : "<none>");
783             }
784             ClientMetadataHandle trailing_metadata;
785             if (status.ok()) {
786               trailing_metadata = std::move(*status);
787             } else {
788               trailing_metadata =
789                   GetContext<Arena>()->MakePooled<ClientMetadata>();
790               grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
791               std::string message;
792               grpc_error_get_status(status.status(), Timestamp::InfFuture(),
793                                     &status_code, &message, nullptr, nullptr);
794               trailing_metadata->Set(GrpcStatusMetadata(), status_code);
795               trailing_metadata->Set(GrpcMessageMetadata(),
796                                      Slice::FromCopiedString(message));
797             }
798             if (trailing_metadata->get(GrpcStatusMetadata())
799                     .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
800               if (!failure_latch->is_set()) {
801                 failure_latch->Set(std::move(trailing_metadata));
802               }
803             }
804             return Empty{};
805           }),
806       [](Empty) {});
807 
808   // Finally assemble the main call promise:
809   // Receive initial metadata from the client and start the promise up the
810   // filter stack.
811   // Upon completion, send trailing metadata to the client and then return it
812   // (allowing the call code to decide on what signalling to give the
813   // application).
814 
815   struct CleanupPollingEntityLatch {
816     void operator()(Latch<grpc_polling_entity>* latch) {
817       if (!latch->is_set()) latch->Set(grpc_polling_entity());
818     }
819   };
820   auto cleanup_polling_entity_latch =
821       std::unique_ptr<Latch<grpc_polling_entity>, CleanupPollingEntityLatch>(
822           &call_data->polling_entity_latch);
823   struct CleanupSendInitialMetadata {
824     void operator()(CallData* call_data) {
825       call_data->server_initial_metadata.receiver.CloseWithError();
826     }
827   };
828   auto cleanup_send_initial_metadata =
829       std::unique_ptr<CallData, CleanupSendInitialMetadata>(call_data);
830 
831   return Map(
832       Seq(std::move(recv_initial_metadata_then_run_promise),
833           std::move(send_trailing_metadata)),
834       [cleanup_polling_entity_latch = std::move(cleanup_polling_entity_latch),
835        cleanup_send_initial_metadata = std::move(cleanup_send_initial_metadata),
836        stream = std::move(stream)](ServerMetadataHandle md) {
837         stream->set_finished();
838         return md;
839       });
840 }
841 #endif
842 
843 template <ArenaPromise<ServerMetadataHandle> (*make_call_promise)(
844     Transport*, CallArgs, NextPromiseFactory)>
845 grpc_channel_filter MakeConnectedFilter() {
846   // Create a vtable that contains both the legacy call methods (for filter
847   // stack based calls) and the new promise based method for creating
848   // promise based calls (the latter iff make_call_promise != nullptr). In
849   // this way the filter can be inserted into either kind of channel stack,
850   // and only if all the filters in the stack are promise based will the
851   // call be promise based.
852   auto make_call_wrapper = +[](grpc_channel_element* elem, CallArgs call_args,
853                                NextPromiseFactory next) {
854     Transport* transport =
855         static_cast<channel_data*>(elem->channel_data)->transport;
856     return make_call_promise(transport, std::move(call_args), std::move(next));
857   };
858   return {
859       connected_channel_start_transport_stream_op_batch,
860       make_call_promise != nullptr ? make_call_wrapper : nullptr,
861       /* init_call: */ nullptr,
862       connected_channel_start_transport_op,
863       sizeof(call_data),
864       connected_channel_init_call_elem,
865       set_pollset_or_pollset_set,
866       connected_channel_destroy_call_elem,
867       sizeof(channel_data),
868       connected_channel_init_channel_elem,
869       +[](grpc_channel_stack* channel_stack, grpc_channel_element* elem) {
870         // HACK(ctiller): increase call stack size for the channel to make
871         // space for channel data. We need a cleaner (but performant) way to
872         // do this, and I'm not sure what that is yet. This is only "safe"
873         // because call stacks place no additional data after the last call
874         // element, and the last call element MUST be the connected channel.
875         auto* transport =
876             static_cast<channel_data*>(elem->channel_data)->transport;
877         if (transport->filter_stack_transport() != nullptr) {
878           channel_stack->call_stack_size +=
879               transport->filter_stack_transport()->SizeOfStream();
880         }
881       },
882       connected_channel_destroy_channel_elem,
883       connected_channel_get_channel_info,
884       "connected",
885   };
886 }
887 
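    // Used when the transport natively supports promise-based calls: hand the
    // call to the transport as a call spine and map its trailing metadata into
    // the call's result.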
888 ArenaPromise<ServerMetadataHandle> MakeClientTransportCallPromise(
889     Transport* transport, CallArgs call_args, NextPromiseFactory) {
890   auto spine = GetContext<CallContext>()->MakeCallSpine(std::move(call_args));
891   transport->client_transport()->StartCall(CallHandler{spine});
892   return Map(spine->server_trailing_metadata().receiver.Next(),
893              [](NextResult<ServerMetadataHandle> r) {
894                if (r.has_value()) {
895                  auto md = std::move(r.value());
896                  md->Set(GrpcStatusFromWire(), true);
897                  return md;
898                }
899                auto m = GetContext<Arena>()->MakePooled<ServerMetadata>();
900                m->Set(GrpcStatusMetadata(), GRPC_STATUS_CANCELLED);
901                m->Set(GrpcCallWasCancelled(), true);
902                return m;
903              });
904 }
905 
906 const grpc_channel_filter kClientPromiseBasedTransportFilter =
907     MakeConnectedFilter<MakeClientTransportCallPromise>();
908 
909 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
910 const grpc_channel_filter kClientEmulatedFilter =
911     MakeConnectedFilter<MakeClientCallPromise>();
912 #else
913 const grpc_channel_filter kClientEmulatedFilter =
914     MakeConnectedFilter<nullptr>();
915 #endif
916 
917 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
918 const grpc_channel_filter kServerEmulatedFilter =
919     MakeConnectedFilter<MakeServerCallPromise>();
920 #else
921 const grpc_channel_filter kServerEmulatedFilter =
922     MakeConnectedFilter<nullptr>();
923 #endif
924 
925 // noop filter for the v3 stack: placeholder for now because other code requires
926 // we have a terminator.
927 // TODO(ctiller): delete when v3 transition is complete.
928 const grpc_channel_filter kServerPromiseBasedTransportFilter = {
929     nullptr,
930     [](grpc_channel_element*, CallArgs, NextPromiseFactory)
931         -> ArenaPromise<ServerMetadataHandle> { Crash("not implemented"); },
932     /* init_call: */ [](grpc_channel_element*, CallSpineInterface*) {},
933     connected_channel_start_transport_op,
934     0,
935     nullptr,
936     set_pollset_or_pollset_set,
937     nullptr,
938     sizeof(channel_data),
939     connected_channel_init_channel_elem,
940     +[](grpc_channel_stack*, grpc_channel_element*) {},
941     connected_channel_destroy_channel_elem,
942     connected_channel_get_channel_info,
943     "connected",
944 };
945 
946 bool TransportSupportsClientPromiseBasedCalls(const ChannelArgs& args) {
947   auto* transport = args.GetObject<Transport>();
948   return transport->client_transport() != nullptr;
949 }
950 
951 bool TransportSupportsServerPromiseBasedCalls(const ChannelArgs& args) {
952   auto* transport = args.GetObject<Transport>();
953   return transport->server_transport() != nullptr;
954 }
955 
956 }  // namespace
957 
958 void RegisterConnectedChannel(CoreConfiguration::Builder* builder) {
959   // We can't know promise based call or not here (that decision needs the
960   // collaboration of all of the filters on the channel, and we don't want
961   // ordering constraints on when we add filters).
962   // We can, however, know how we would create our promise if this does turn
963   // out to be promise based, and so that is the choice made here.
964 
965   // Option 1, and our ideal: the transport supports promise based calls,
966   // and so we simply use the transport directly.
967   builder->channel_init()
968       ->RegisterFilter(GRPC_CLIENT_SUBCHANNEL,
969                        &kClientPromiseBasedTransportFilter)
970       .Terminal()
971       .If(TransportSupportsClientPromiseBasedCalls);
972   builder->channel_init()
973       ->RegisterFilter(GRPC_CLIENT_DIRECT_CHANNEL,
974                        &kClientPromiseBasedTransportFilter)
975       .Terminal()
976       .If(TransportSupportsClientPromiseBasedCalls);
977   builder->channel_init()
978       ->RegisterFilter(GRPC_SERVER_CHANNEL, &kServerPromiseBasedTransportFilter)
979       .Terminal()
980       .If(TransportSupportsServerPromiseBasedCalls);
981 
982   // Option 2: the transport does not support promise based calls.
983   builder->channel_init()
984       ->RegisterFilter(GRPC_CLIENT_SUBCHANNEL, &kClientEmulatedFilter)
985       .Terminal()
986       .IfNot(TransportSupportsClientPromiseBasedCalls);
987   builder->channel_init()
988       ->RegisterFilter(GRPC_CLIENT_DIRECT_CHANNEL, &kClientEmulatedFilter)
989       .Terminal()
990       .IfNot(TransportSupportsClientPromiseBasedCalls);
991   builder->channel_init()
992       ->RegisterFilter(GRPC_SERVER_CHANNEL, &kServerEmulatedFilter)
993       .Terminal()
994       .IfNot(TransportSupportsServerPromiseBasedCalls);
995 }
996 
997 }  // namespace grpc_core
998