1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/lib/channel/connected_channel.h"
22
23 #include <inttypes.h>
24
25 #include <functional>
26 #include <memory>
27 #include <string>
28 #include <type_traits>
29 #include <utility>
30
31 #include "absl/status/status.h"
32 #include "absl/status/statusor.h"
33 #include "absl/types/optional.h"
34
35 #include <grpc/grpc.h>
36 #include <grpc/status.h>
37 #include <grpc/support/alloc.h>
38 #include <grpc/support/log.h>
39
40 #include "src/core/lib/channel/call_finalization.h"
41 #include "src/core/lib/channel/channel_args.h"
42 #include "src/core/lib/channel/channel_fwd.h"
43 #include "src/core/lib/channel/channel_stack.h"
44 #include "src/core/lib/config/core_configuration.h"
45 #include "src/core/lib/debug/trace.h"
46 #include "src/core/lib/experiments/experiments.h"
47 #include "src/core/lib/gpr/alloc.h"
48 #include "src/core/lib/gprpp/debug_location.h"
49 #include "src/core/lib/gprpp/orphanable.h"
50 #include "src/core/lib/gprpp/ref_counted_ptr.h"
51 #include "src/core/lib/gprpp/time.h"
52 #include "src/core/lib/iomgr/call_combiner.h"
53 #include "src/core/lib/iomgr/closure.h"
54 #include "src/core/lib/iomgr/error.h"
55 #include "src/core/lib/iomgr/polling_entity.h"
56 #include "src/core/lib/promise/activity.h"
57 #include "src/core/lib/promise/arena_promise.h"
58 #include "src/core/lib/promise/context.h"
59 #include "src/core/lib/promise/detail/status.h"
60 #include "src/core/lib/promise/for_each.h"
61 #include "src/core/lib/promise/if.h"
62 #include "src/core/lib/promise/latch.h"
63 #include "src/core/lib/promise/loop.h"
64 #include "src/core/lib/promise/map.h"
65 #include "src/core/lib/promise/party.h"
66 #include "src/core/lib/promise/pipe.h"
67 #include "src/core/lib/promise/poll.h"
68 #include "src/core/lib/promise/promise.h"
69 #include "src/core/lib/promise/race.h"
70 #include "src/core/lib/promise/seq.h"
71 #include "src/core/lib/promise/try_seq.h"
72 #include "src/core/lib/resource_quota/arena.h"
73 #include "src/core/lib/slice/slice.h"
74 #include "src/core/lib/slice/slice_buffer.h"
75 #include "src/core/lib/surface/call.h"
76 #include "src/core/lib/surface/call_trace.h"
77 #include "src/core/lib/surface/channel_stack_type.h"
78 #include "src/core/lib/transport/batch_builder.h"
79 #include "src/core/lib/transport/error_utils.h"
80 #include "src/core/lib/transport/metadata_batch.h"
81 #include "src/core/lib/transport/transport.h"
82
83 typedef struct connected_channel_channel_data {
84 grpc_core::Transport* transport;
85 } channel_data;
86
87 struct callback_state {
88 grpc_closure closure;
89 grpc_closure* original_closure;
90 grpc_core::CallCombiner* call_combiner;
91 const char* reason;
92 };
93 typedef struct connected_channel_call_data {
94 grpc_core::CallCombiner* call_combiner;
95 // Closures used for returning results on the call combiner.
96 callback_state on_complete[6]; // Max number of pending batches.
97 callback_state recv_initial_metadata_ready;
98 callback_state recv_message_ready;
99 callback_state recv_trailing_metadata_ready;
100 } call_data;
101
run_in_call_combiner(void * arg,grpc_error_handle error)102 static void run_in_call_combiner(void* arg, grpc_error_handle error) {
103 callback_state* state = static_cast<callback_state*>(arg);
104 GRPC_CALL_COMBINER_START(state->call_combiner, state->original_closure, error,
105 state->reason);
106 }
107
run_cancel_in_call_combiner(void * arg,grpc_error_handle error)108 static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) {
109 run_in_call_combiner(arg, error);
110 gpr_free(arg);
111 }
112
intercept_callback(call_data * calld,callback_state * state,bool free_when_done,const char * reason,grpc_closure ** original_closure)113 static void intercept_callback(call_data* calld, callback_state* state,
114 bool free_when_done, const char* reason,
115 grpc_closure** original_closure) {
116 state->original_closure = *original_closure;
117 state->call_combiner = calld->call_combiner;
118 state->reason = reason;
119 *original_closure = GRPC_CLOSURE_INIT(
120 &state->closure,
121 free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner,
122 state, grpc_schedule_on_exec_ctx);
123 }
124
get_state_for_batch(call_data * calld,grpc_transport_stream_op_batch * batch)125 static callback_state* get_state_for_batch(
126 call_data* calld, grpc_transport_stream_op_batch* batch) {
127 if (batch->send_initial_metadata) return &calld->on_complete[0];
128 if (batch->send_message) return &calld->on_complete[1];
129 if (batch->send_trailing_metadata) return &calld->on_complete[2];
130 if (batch->recv_initial_metadata) return &calld->on_complete[3];
131 if (batch->recv_message) return &calld->on_complete[4];
132 if (batch->recv_trailing_metadata) return &calld->on_complete[5];
133 GPR_UNREACHABLE_CODE(return nullptr);
134 }
135
// We perform a small hack to locate the transport's stream data alongside the
// connected channel's call data in the same call allocation, so that
// everything can be pulled in with a minimal number of cache-line requests.
//
// Layout assumed by the two macros below:
//   [ call_data | padding to GPR_ROUND_UP_TO_ALIGNMENT_SIZE | grpc_stream ]
// They convert between a call_data pointer and the transport stream pointer
// in that layout. Named casts are used (instead of C-style casts) so that a
// const-qualified argument is rejected rather than silently losing its const.
#define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \
  (reinterpret_cast<grpc_stream*>(              \
      reinterpret_cast<char*>(calld) +          \
      GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
#define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \
  (reinterpret_cast<call_data*>(                          \
      reinterpret_cast<char*>(transport_stream) -         \
      GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
145
146 // Intercept a call operation and either push it directly up or translate it
147 // into transport stream operations
connected_channel_start_transport_stream_op_batch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)148 static void connected_channel_start_transport_stream_op_batch(
149 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
150 call_data* calld = static_cast<call_data*>(elem->call_data);
151 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
152 if (batch->recv_initial_metadata) {
153 callback_state* state = &calld->recv_initial_metadata_ready;
154 intercept_callback(
155 calld, state, false, "recv_initial_metadata_ready",
156 &batch->payload->recv_initial_metadata.recv_initial_metadata_ready);
157 }
158 if (batch->recv_message) {
159 callback_state* state = &calld->recv_message_ready;
160 intercept_callback(calld, state, false, "recv_message_ready",
161 &batch->payload->recv_message.recv_message_ready);
162 }
163 if (batch->recv_trailing_metadata) {
164 callback_state* state = &calld->recv_trailing_metadata_ready;
165 intercept_callback(
166 calld, state, false, "recv_trailing_metadata_ready",
167 &batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready);
168 }
169 if (batch->cancel_stream) {
170 // There can be more than one cancellation batch in flight at any
171 // given time, so we can't just pick out a fixed index into
172 // calld->on_complete like we can for the other ops. However,
173 // cancellation isn't in the fast path, so we just allocate a new
174 // closure for each one.
175 callback_state* state =
176 static_cast<callback_state*>(gpr_malloc(sizeof(*state)));
177 intercept_callback(calld, state, true, "on_complete (cancel_stream)",
178 &batch->on_complete);
179 } else if (batch->on_complete != nullptr) {
180 callback_state* state = get_state_for_batch(calld, batch);
181 intercept_callback(calld, state, false, "on_complete", &batch->on_complete);
182 }
183 chand->transport->filter_stack_transport()->PerformStreamOp(
184 TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch);
185 GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport");
186 }
187
connected_channel_start_transport_op(grpc_channel_element * elem,grpc_transport_op * op)188 static void connected_channel_start_transport_op(grpc_channel_element* elem,
189 grpc_transport_op* op) {
190 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
191 chand->transport->PerformOp(op);
192 }
193
194 // Constructor for call_data
connected_channel_init_call_elem(grpc_call_element * elem,const grpc_call_element_args * args)195 static grpc_error_handle connected_channel_init_call_elem(
196 grpc_call_element* elem, const grpc_call_element_args* args) {
197 call_data* calld = static_cast<call_data*>(elem->call_data);
198 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
199 calld->call_combiner = args->call_combiner;
200 chand->transport->filter_stack_transport()->InitStream(
201 TRANSPORT_STREAM_FROM_CALL_DATA(calld), &args->call_stack->refcount,
202 args->server_transport_data, args->arena);
203 return absl::OkStatus();
204 }
205
set_pollset_or_pollset_set(grpc_call_element * elem,grpc_polling_entity * pollent)206 static void set_pollset_or_pollset_set(grpc_call_element* elem,
207 grpc_polling_entity* pollent) {
208 call_data* calld = static_cast<call_data*>(elem->call_data);
209 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
210 chand->transport->SetPollingEntity(TRANSPORT_STREAM_FROM_CALL_DATA(calld),
211 pollent);
212 }
213
214 // Destructor for call_data
connected_channel_destroy_call_elem(grpc_call_element * elem,const grpc_call_final_info *,grpc_closure * then_schedule_closure)215 static void connected_channel_destroy_call_elem(
216 grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
217 grpc_closure* then_schedule_closure) {
218 call_data* calld = static_cast<call_data*>(elem->call_data);
219 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
220 chand->transport->filter_stack_transport()->DestroyStream(
221 TRANSPORT_STREAM_FROM_CALL_DATA(calld), then_schedule_closure);
222 }
223
224 // Constructor for channel_data
connected_channel_init_channel_elem(grpc_channel_element * elem,grpc_channel_element_args * args)225 static grpc_error_handle connected_channel_init_channel_elem(
226 grpc_channel_element* elem, grpc_channel_element_args* args) {
227 channel_data* cd = static_cast<channel_data*>(elem->channel_data);
228 GPR_ASSERT(args->is_last);
229 cd->transport = args->channel_args.GetObject<grpc_core::Transport>();
230 return absl::OkStatus();
231 }
232
233 // Destructor for channel_data
connected_channel_destroy_channel_elem(grpc_channel_element * elem)234 static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) {
235 channel_data* cd = static_cast<channel_data*>(elem->channel_data);
236 if (cd->transport) {
237 cd->transport->Orphan();
238 }
239 }
240
241 // No-op.
connected_channel_get_channel_info(grpc_channel_element *,const grpc_channel_info *)242 static void connected_channel_get_channel_info(
243 grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) {
244 }
245
246 namespace grpc_core {
247 namespace {
248
249 #if defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || \
250 defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
// Owns a transport stream for a promise-based call and manages its lifetime.
//
// Lifetime is governed by `stream_refcount_`, which starts at 1. Orphan()
// drops one reference, and InternalRef() hands out additional ones. When the
// count reaches zero, BeginDestroy() runs:
//   - If a transport stream was installed, resetting `stream_` invokes
//     StreamDeleter, which asks the transport to destroy the stream and passes
//     `stream_destroyed_`. StreamDestroyed() then runs when the transport
//     invokes that closure.
//   - Otherwise StreamDestroyed() runs directly.
// StreamDestroyed() runs this object's destructor inside the call context. It
// does not free memory. NOTE(review): presumably the object is arena-owned
// (callers create it with Arena::New) — confirm before changing allocation.
class ConnectedChannelStream : public Orphanable {
 public:
  explicit ConnectedChannelStream(Transport* transport)
      : transport_(transport), stream_(nullptr, StreamDeleter(this)) {
    // Start with one reference; reaching zero triggers BeginDestroy().
    GRPC_STREAM_REF_INIT(
        &stream_refcount_, 1,
        [](void* p, grpc_error_handle) {
          static_cast<ConnectedChannelStream*>(p)->BeginDestroy();
        },
        this, "ConnectedChannelStream");
  }

  Transport* transport() { return transport_; }
  // Closure handed to the transport's DestroyStream(); runs StreamDestroyed().
  grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; }

  // Target descriptor (transport, stream, refcount) for BatchBuilder calls.
  BatchBuilder::Target batch_target() {
    return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_};
  }

  // Adds a reference to stream_refcount_. The reason is only recorded in
  // debug builds.
  void IncrementRefCount(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_ref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_ref(&stream_refcount_);
#endif
  }

  // Drops a reference from stream_refcount_. The last unref triggers
  // BeginDestroy(). The reason is only recorded in debug builds.
  void Unref(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_unref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_unref(&stream_refcount_);
#endif
  }

  // Returns a new owning reference. The RefCountedPtr releases it via Unref().
  RefCountedPtr<ConnectedChannelStream> InternalRef() {
    IncrementRefCount("smartptr");
    return RefCountedPtr<ConnectedChannelStream>(this);
  }

  // Drops the owner's reference. If the stream was not yet observed to be
  // finished, this also does two things:
  //   - spawns a task (holding its own ref) that sets `finished_`, waking
  //     WaitFinished() waiters;
  //   - cancels the stream at the transport.
  void Orphan() final {
    bool finished = finished_.IsSet();
    if (grpc_call_trace.enabled()) {
      gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d",
              party_->DebugTag().c_str(), finished);
    }
    // If we hadn't already observed the stream to be finished, we need to
    // cancel it at the transport.
    if (!finished) {
      party_->Spawn(
          "finish",
          [self = InternalRef()]() {
            if (!self->finished_.IsSet()) {
              self->finished_.Set();
            }
            return Empty{};
          },
          [](Empty) {});
      GetContext<BatchBuilder>()->Cancel(batch_target(),
                                         absl::CancelledError());
    }
    Unref("orphan connected stream");
  }

  // Returns a promise that implements the receive message loop.
  auto RecvMessages(PipeSender<MessageHandle>* incoming_messages,
                    bool cancel_on_error);
  // Returns a promise that implements the send message loop.
  auto SendMessages(PipeReceiver<MessageHandle>* outgoing_messages);

  // Installs the transport stream. A previously installed stream, if any,
  // would be passed to StreamDeleter (i.e. destroyed at the transport).
  void SetStream(grpc_stream* stream) { stream_.reset(stream); }
  grpc_stream* stream() { return stream_.get(); }
  grpc_stream_refcount* stream_refcount() { return &stream_refcount_; }

  // Marks the stream finished, releasing WaitFinished() waiters.
  void set_finished() { finished_.Set(); }
  auto WaitFinished() { return finished_.Wait(); }

 private:
  // unique_ptr deleter: asks the transport to destroy the stream and run
  // stream_destroyed_closure() afterwards (assumed to be invoked by the
  // transport once destruction completes — NOTE(review): confirm).
  class StreamDeleter {
   public:
    explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {}
    void operator()(grpc_stream* stream) const {
      if (stream == nullptr) return;
      impl_->transport()->filter_stack_transport()->DestroyStream(
          stream, impl_->stream_destroyed_closure());
    }

   private:
    ConnectedChannelStream* impl_;
  };
  using StreamPtr = std::unique_ptr<grpc_stream, StreamDeleter>;

  // Final step of teardown: runs the destructor inside the call context.
  void StreamDestroyed() {
    call_context_->RunInContext([this] { this->~ConnectedChannelStream(); });
  }

  // Invoked when stream_refcount_ reaches zero (see the constructor).
  void BeginDestroy() {
    if (stream_ != nullptr) {
      stream_.reset();
    } else {
      StreamDestroyed();
    }
  }

  // NOTE: members are initialized in declaration order. call_context_,
  // arena_ and party_ are captured from the constructing context.
  Transport* const transport_;
  RefCountedPtr<CallContext> const call_context_{
      GetContext<CallContext>()->Ref()};
  grpc_closure stream_destroyed_ =
      MakeMemberClosure<ConnectedChannelStream,
                        &ConnectedChannelStream::StreamDestroyed>(
          this, DEBUG_LOCATION);
  grpc_stream_refcount stream_refcount_;
  StreamPtr stream_;
  // NOTE(review): arena_ is not read by any member visible in this file
  // chunk — confirm whether it is still needed.
  Arena* arena_ = GetContext<Arena>();
  Party* const party_ = GetContext<Party>();
  ExternallyObservableLatch<void> finished_;
};
370
// Builds a promise that loops, receiving messages from the transport and
// pushing them into `incoming_messages`. The sender is moved into the loop,
// and the loop holds a stream reference (InternalRef) while it exists.
//
// Each iteration asks the context's BatchBuilder for one message:
//   - If a message arrives, it is pushed into the pipe. The loop continues,
//     unless the push fails (the application side is gone), in which case the
//     loop ends with OkStatus.
//   - Otherwise (end of stream, or an error), the pipe is closed and the loop
//     ends with the received status (OkStatus on a clean end of stream). The
//     pipe is closed with CloseWithError() when `cancel_on_error` is set and
//     the status is an error, and with Close() otherwise.
auto ConnectedChannelStream::RecvMessages(
    PipeSender<MessageHandle>* incoming_messages, bool cancel_on_error) {
  return Loop([self = InternalRef(), cancel_on_error,
               incoming_messages = std::move(*incoming_messages)]() mutable {
    return Seq(
        GetContext<BatchBuilder>()->ReceiveMessage(self->batch_target()),
        [cancel_on_error, &incoming_messages](
            absl::StatusOr<absl::optional<MessageHandle>> status) mutable {
          bool has_message = status.ok() && status->has_value();
          // NOTE(review): both factories below capture `status` by reference.
          // This relies on If() invoking the selected factory before this
          // lambda returns — confirm before restructuring.
          auto publish_message = [&incoming_messages, &status]() {
            auto pending_message = std::move(**status);
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: received payload of %" PRIdPTR
                      " bytes",
                      GetContext<Activity>()->DebugTag().c_str(),
                      pending_message->payload()->Length());
            }
            // A failed push ends the loop successfully; a successful push
            // continues it.
            return Map(incoming_messages.Push(std::move(pending_message)),
                       [](bool ok) -> LoopCtl<absl::Status> {
                         if (!ok) {
                           if (grpc_call_trace.enabled()) {
                             gpr_log(
                                 GPR_INFO,
                                 "%s[connected] RecvMessage: failed to "
                                 "push message towards the application",
                                 GetContext<Activity>()->DebugTag().c_str());
                           }
                           return absl::OkStatus();
                         }
                         return Continue{};
                       });
          };
          auto publish_close = [cancel_on_error, &incoming_messages,
                                &status]() mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: reached end of stream with "
                      "status:%s",
                      GetContext<Activity>()->DebugTag().c_str(),
                      status.status().ToString().c_str());
            }
            if (cancel_on_error && !status.ok()) {
              incoming_messages.CloseWithError();
            } else {
              incoming_messages.Close();
            }
            return Immediate(LoopCtl<absl::Status>(status.status()));
          };
          return If(has_message, std::move(publish_message),
                    std::move(publish_close));
        });
  });
}
425
SendMessages(PipeReceiver<MessageHandle> * outgoing_messages)426 auto ConnectedChannelStream::SendMessages(
427 PipeReceiver<MessageHandle>* outgoing_messages) {
428 return ForEach(std::move(*outgoing_messages),
429 [self = InternalRef()](MessageHandle message) {
430 return GetContext<BatchBuilder>()->SendMessage(
431 self->batch_target(), std::move(message));
432 });
433 }
434 #endif // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) ||
435 // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
436
437 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
MakeClientCallPromise(Transport * transport,CallArgs call_args,NextPromiseFactory)438 ArenaPromise<ServerMetadataHandle> MakeClientCallPromise(Transport* transport,
439 CallArgs call_args,
440 NextPromiseFactory) {
441 OrphanablePtr<ConnectedChannelStream> stream(
442 GetContext<Arena>()->New<ConnectedChannelStream>(transport));
443 stream->SetStream(static_cast<grpc_stream*>(GetContext<Arena>()->Alloc(
444 transport->filter_stack_transport()->SizeOfStream())));
445 transport->filter_stack_transport()->InitStream(stream->stream(),
446 stream->stream_refcount(),
447 nullptr, GetContext<Arena>());
448 auto* party = GetContext<Party>();
449 party->Spawn("set_polling_entity", call_args.polling_entity->Wait(),
450 [transport, stream = stream->InternalRef()](
451 grpc_polling_entity polling_entity) {
452 transport->SetPollingEntity(stream->stream(), &polling_entity);
453 });
454 // Start a loop to send messages from client_to_server_messages to the
455 // transport. When the pipe closes and the loop completes, send a trailing
456 // metadata batch to close the stream.
457 party->Spawn(
458 "send_messages",
459 TrySeq(stream->SendMessages(call_args.client_to_server_messages),
460 [stream = stream->InternalRef()]() {
461 return GetContext<BatchBuilder>()->SendClientTrailingMetadata(
462 stream->batch_target());
463 }),
464 [](absl::Status) {});
465 // Start a promise to receive server initial metadata and then forward it up
466 // through the receiving pipe.
467 auto server_initial_metadata =
468 GetContext<Arena>()->MakePooled<ServerMetadata>();
469 party->Spawn(
470 "recv_initial_metadata",
471 TrySeq(GetContext<BatchBuilder>()->ReceiveServerInitialMetadata(
472 stream->batch_target()),
473 [pipe = call_args.server_initial_metadata](
474 ServerMetadataHandle server_initial_metadata) {
475 if (grpc_call_trace.enabled()) {
476 gpr_log(GPR_DEBUG,
477 "%s[connected] Publish client initial metadata: %s",
478 GetContext<Activity>()->DebugTag().c_str(),
479 server_initial_metadata->DebugString().c_str());
480 }
481 return Map(pipe->Push(std::move(server_initial_metadata)),
482 [](bool r) {
483 if (r) return absl::OkStatus();
484 return absl::CancelledError();
485 });
486 }),
487 [](absl::Status) {});
488
489 // Build up the rest of the main call promise:
490
491 // Create a promise that will send initial metadata and then signal completion
492 // of that via the token.
493 auto send_initial_metadata = Seq(
494 GetContext<BatchBuilder>()->SendClientInitialMetadata(
495 stream->batch_target(), std::move(call_args.client_initial_metadata)),
496 [sent_initial_metadata_token =
497 std::move(call_args.client_initial_metadata_outstanding)](
498 absl::Status status) mutable {
499 sent_initial_metadata_token.Complete(status.ok());
500 return status;
501 });
502 // Create a promise that will receive server trailing metadata.
503 // If this fails, we massage the error into metadata that we can report
504 // upwards.
505 auto server_trailing_metadata =
506 GetContext<Arena>()->MakePooled<ServerMetadata>();
507 auto recv_trailing_metadata =
508 Map(GetContext<BatchBuilder>()->ReceiveServerTrailingMetadata(
509 stream->batch_target()),
510 [](absl::StatusOr<ServerMetadataHandle> status) mutable {
511 if (!status.ok()) {
512 auto server_trailing_metadata =
513 GetContext<Arena>()->MakePooled<ServerMetadata>();
514 grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
515 std::string message;
516 grpc_error_get_status(status.status(), Timestamp::InfFuture(),
517 &status_code, &message, nullptr, nullptr);
518 server_trailing_metadata->Set(GrpcStatusMetadata(), status_code);
519 server_trailing_metadata->Set(GrpcMessageMetadata(),
520 Slice::FromCopiedString(message));
521 return server_trailing_metadata;
522 } else {
523 return std::move(*status);
524 }
525 });
526 // Finally the main call promise.
527 // Concurrently: send initial metadata and receive messages, until BOTH
528 // complete (or one fails).
529 // Next: receive trailing metadata, and return that up the stack.
530 auto recv_messages =
531 stream->RecvMessages(call_args.server_to_client_messages, false);
532 return Map(
533 [send_initial_metadata = std::move(send_initial_metadata),
534 recv_messages = std::move(recv_messages),
535 recv_trailing_metadata = std::move(recv_trailing_metadata),
536 done_send_initial_metadata = false, done_recv_messages = false,
537 done_recv_trailing_metadata =
538 false]() mutable -> Poll<ServerMetadataHandle> {
539 if (!done_send_initial_metadata) {
540 auto p = send_initial_metadata();
541 if (auto* r = p.value_if_ready()) {
542 done_send_initial_metadata = true;
543 if (!r->ok()) return StatusCast<ServerMetadataHandle>(*r);
544 }
545 }
546 if (!done_recv_messages) {
547 auto p = recv_messages();
548 if (p.ready()) {
549 // NOTE: ignore errors here, they'll be collected in the
550 // recv_trailing_metadata.
551 done_recv_messages = true;
552 } else {
553 return Pending{};
554 }
555 }
556 if (!done_recv_trailing_metadata) {
557 auto p = recv_trailing_metadata();
558 if (auto* r = p.value_if_ready()) {
559 done_recv_trailing_metadata = true;
560 return std::move(*r);
561 }
562 }
563 return Pending{};
564 },
565 [stream = std::move(stream)](ServerMetadataHandle result) {
566 stream->set_finished();
567 return result;
568 });
569 }
570 #endif
571
572 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
MakeServerCallPromise(Transport * transport,CallArgs,NextPromiseFactory next_promise_factory)573 ArenaPromise<ServerMetadataHandle> MakeServerCallPromise(
574 Transport* transport, CallArgs, NextPromiseFactory next_promise_factory) {
575 OrphanablePtr<ConnectedChannelStream> stream(
576 GetContext<Arena>()->New<ConnectedChannelStream>(transport));
577
578 stream->SetStream(static_cast<grpc_stream*>(GetContext<Arena>()->Alloc(
579 transport->filter_stack_transport()->SizeOfStream())));
580 transport->filter_stack_transport()->InitStream(
581 stream->stream(), stream->stream_refcount(),
582 GetContext<CallContext>()->server_call_context()->server_stream_data(),
583 GetContext<Arena>());
584 auto* party = GetContext<Party>();
585
586 // Arifacts we need for the lifetime of the call.
587 struct CallData {
588 Pipe<MessageHandle> server_to_client;
589 Pipe<MessageHandle> client_to_server;
590 Pipe<ServerMetadataHandle> server_initial_metadata;
591 Latch<ServerMetadataHandle> failure_latch;
592 Latch<grpc_polling_entity> polling_entity_latch;
593 bool sent_initial_metadata = false;
594 bool sent_trailing_metadata = false;
595 };
596 auto* call_data = GetContext<Arena>()->New<CallData>();
597 GetContext<CallFinalization>()->Add(
598 [call_data](const grpc_call_final_info*) { call_data->~CallData(); });
599
600 party->Spawn("set_polling_entity", call_data->polling_entity_latch.Wait(),
601 [transport, stream = stream->InternalRef()](
602 grpc_polling_entity polling_entity) {
603 transport->SetPollingEntity(stream->stream(), &polling_entity);
604 });
605
606 auto server_to_client_empty =
607 call_data->server_to_client.receiver.AwaitEmpty();
608
609 // Create a promise that will receive client initial metadata, and then run
610 // the main stem of the call (calling next_promise_factory up through the
611 // filters).
612 // Race the main call with failure_latch, allowing us to forcefully complete
613 // the call in the case of a failure.
614 auto recv_initial_metadata_then_run_promise =
615 TrySeq(GetContext<BatchBuilder>()->ReceiveClientInitialMetadata(
616 stream->batch_target()),
617 [next_promise_factory = std::move(next_promise_factory),
618 server_to_client_empty = std::move(server_to_client_empty),
619 call_data](ClientMetadataHandle client_initial_metadata) {
620 auto call_promise = next_promise_factory(CallArgs{
621 std::move(client_initial_metadata),
622 ClientInitialMetadataOutstandingToken::Empty(),
623 &call_data->polling_entity_latch,
624 &call_data->server_initial_metadata.sender,
625 &call_data->client_to_server.receiver,
626 &call_data->server_to_client.sender,
627 });
628 return Race(call_data->failure_latch.Wait(),
629 [call_promise = std::move(call_promise),
630 server_to_client_empty =
631 std::move(server_to_client_empty)]() mutable
632 -> Poll<ServerMetadataHandle> {
633 // TODO(ctiller): this is deeply weird and we need
634 // to clean this up.
635 //
636 // The following few lines check to ensure that
637 // there's no message currently pending in the
638 // outgoing message queue, and if (and only if)
639 // that's true decides to poll the main promise to
640 // see if there's a result.
641 //
642 // This essentially introduces a polling priority
643 // scheme that makes the current promise structure
644 // work out the way we want when talking to
645 // transports.
646 //
647 // The problem is that transports are going to need
648 // to replicate this structure when they convert to
649 // promises, and that becomes troubling as we'll be
650 // replicating weird throughout the stack.
651 //
652 // Instead we likely need to change the way we're
653 // composing promises through the stack.
654 //
655 // Proposed is to change filters from a promise
656 // that takes ClientInitialMetadata and returns
657 // ServerTrailingMetadata with three pipes for
658 // ServerInitialMetadata and
659 // ClientToServerMessages, ServerToClientMessages.
660 // Instead we'll have five pipes, moving
661 // ClientInitialMetadata and ServerTrailingMetadata
662 // to pipes that can be intercepted.
663 //
664 // The effect of this change will be to cripple the
665 // things that can be done in a filter (but cripple
666 // in line with what most filters actually do).
667 // We'll likely need to add a `CallContext::Cancel`
668 // to allow filters to cancel a request, but this
669 // would also have the advantage of centralizing
670 // our cancellation machinery which seems like an
671 // additional win - with the net effect that the
672 // shape of the call gets made explicit at the top
673 // & bottom of the stack.
674 //
675 // There's a small set of filters (retry, this one,
676 // lame client, clinet channel) that terminate
677 // stacks and need a richer set of semantics, but
678 // that ends up being fine because we can spawn
679 // tasks in parties to handle those edge cases, and
680 // keep the majority of filters simple: they just
681 // call InterceptAndMap on a handful of filters at
682 // call initialization time and then proceed to
683 // actually filter.
684 //
685 // So that's the plan, why isn't it enacted here?
686 //
687 // Well, the plan ends up being easy to implement
688 // in the promise based world (I did a prototype on
689 // a branch in an afternoon). It's heinous to
690 // implement in promise_based_filter, and that code
691 // is load bearing for us at the time of writing.
692 // It's not worth delaying promises for a further N
693 // months (N ~ 6) to make that change.
694 //
695 // Instead, we'll move forward with this, get
696 // promise_based_filter out of the picture, and
697 // then during the mop-up phase for promises tweak
698 // the compute structure to move to the magical
699 // five pipes (I'm reminded of an old Onion
700 // article), and end up in a good happy place.
701 if (server_to_client_empty().pending()) {
702 return Pending{};
703 }
704 return call_promise();
705 });
706 });
707
708 // Promise factory that accepts a ServerMetadataHandle, and sends it as the
709 // trailing metadata for this call.
710 auto send_trailing_metadata = [call_data, stream = stream->InternalRef()](
711 ServerMetadataHandle
712 server_trailing_metadata) {
713 bool is_cancellation =
714 server_trailing_metadata->get(GrpcCallWasCancelled()).value_or(false);
715 return GetContext<BatchBuilder>()->SendServerTrailingMetadata(
716 stream->batch_target(), std::move(server_trailing_metadata),
717 is_cancellation ||
718 !std::exchange(call_data->sent_initial_metadata, true));
719 };
720
721 // Runs the receive message loop, either until all the messages
722 // are received or the server call is complete.
723 party->Spawn(
724 "recv_messages",
725 Race(
726 Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
727 Map(stream->RecvMessages(&call_data->client_to_server.sender, true),
728 [failure_latch = &call_data->failure_latch](absl::Status status) {
729 if (!status.ok() && !failure_latch->is_set()) {
730 failure_latch->Set(ServerMetadataFromStatus(status));
731 }
732 return status;
733 })),
734 [](absl::Status) {});
735
736 // Run a promise that will send initial metadata (if that pipe sends some).
737 // And then run the send message loop until that completes.
738
739 auto send_initial_metadata = Seq(
740 Race(Map(stream->WaitFinished(),
741 [](Empty) { return NextResult<ServerMetadataHandle>(true); }),
742 call_data->server_initial_metadata.receiver.Next()),
743 [call_data, stream = stream->InternalRef()](
744 NextResult<ServerMetadataHandle> next_result) mutable {
745 auto md = !call_data->sent_initial_metadata && next_result.has_value()
746 ? std::move(next_result.value())
747 : nullptr;
748 if (md != nullptr) {
749 call_data->sent_initial_metadata = true;
750 auto* party = GetContext<Party>();
751 party->Spawn("connected/send_initial_metadata",
752 GetContext<BatchBuilder>()->SendServerInitialMetadata(
753 stream->batch_target(), std::move(md)),
754 [](absl::Status) {});
755 return Immediate(absl::OkStatus());
756 }
757 return Immediate(absl::CancelledError());
758 });
759 party->Spawn(
760 "send_initial_metadata_then_messages",
761 Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
762 TrySeq(std::move(send_initial_metadata),
763 stream->SendMessages(&call_data->server_to_client.receiver))),
764 [](absl::Status) {});
765
766 // Spawn a job to fetch the "client trailing metadata" - if this is OK then
767 // it's client done, otherwise it's a signal of cancellation from the client
768 // which we'll use failure_latch to signal.
769
770 party->Spawn(
771 "recv_trailing_metadata",
772 Seq(GetContext<BatchBuilder>()->ReceiveClientTrailingMetadata(
773 stream->batch_target()),
774 [failure_latch = &call_data->failure_latch](
775 absl::StatusOr<ClientMetadataHandle> status) mutable {
776 if (grpc_call_trace.enabled()) {
777 gpr_log(
778 GPR_DEBUG,
779 "%s[connected] Got trailing metadata; status=%s metadata=%s",
780 GetContext<Activity>()->DebugTag().c_str(),
781 status.status().ToString().c_str(),
782 status.ok() ? (*status)->DebugString().c_str() : "<none>");
783 }
784 ClientMetadataHandle trailing_metadata;
785 if (status.ok()) {
786 trailing_metadata = std::move(*status);
787 } else {
788 trailing_metadata =
789 GetContext<Arena>()->MakePooled<ClientMetadata>();
790 grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
791 std::string message;
792 grpc_error_get_status(status.status(), Timestamp::InfFuture(),
793 &status_code, &message, nullptr, nullptr);
794 trailing_metadata->Set(GrpcStatusMetadata(), status_code);
795 trailing_metadata->Set(GrpcMessageMetadata(),
796 Slice::FromCopiedString(message));
797 }
798 if (trailing_metadata->get(GrpcStatusMetadata())
799 .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
800 if (!failure_latch->is_set()) {
801 failure_latch->Set(std::move(trailing_metadata));
802 }
803 }
804 return Empty{};
805 }),
806 [](Empty) {});
807
808 // Finally assemble the main call promise:
809 // Receive initial metadata from the client and start the promise up the
810 // filter stack.
811 // Upon completion, send trailing metadata to the client and then return it
812 // (allowing the call code to decide on what signalling to give the
813 // application).
814
815 struct CleanupPollingEntityLatch {
816 void operator()(Latch<grpc_polling_entity>* latch) {
817 if (!latch->is_set()) latch->Set(grpc_polling_entity());
818 }
819 };
820 auto cleanup_polling_entity_latch =
821 std::unique_ptr<Latch<grpc_polling_entity>, CleanupPollingEntityLatch>(
822 &call_data->polling_entity_latch);
823 struct CleanupSendInitialMetadata {
824 void operator()(CallData* call_data) {
825 call_data->server_initial_metadata.receiver.CloseWithError();
826 }
827 };
828 auto cleanup_send_initial_metadata =
829 std::unique_ptr<CallData, CleanupSendInitialMetadata>(call_data);
830
831 return Map(
832 Seq(std::move(recv_initial_metadata_then_run_promise),
833 std::move(send_trailing_metadata)),
834 [cleanup_polling_entity_latch = std::move(cleanup_polling_entity_latch),
835 cleanup_send_initial_metadata = std::move(cleanup_send_initial_metadata),
836 stream = std::move(stream)](ServerMetadataHandle md) {
837 stream->set_finished();
838 return md;
839 });
840 }
841 #endif
842
843 template <ArenaPromise<ServerMetadataHandle> (*make_call_promise)(
844 Transport*, CallArgs, NextPromiseFactory)>
MakeConnectedFilter()845 grpc_channel_filter MakeConnectedFilter() {
846 // Create a vtable that contains both the legacy call methods (for filter
847 // stack based calls) and the new promise based method for creating
848 // promise based calls (the latter iff make_call_promise != nullptr). In
849 // this way the filter can be inserted into either kind of channel stack,
850 // and only if all the filters in the stack are promise based will the
851 // call be promise based.
852 auto make_call_wrapper = +[](grpc_channel_element* elem, CallArgs call_args,
853 NextPromiseFactory next) {
854 Transport* transport =
855 static_cast<channel_data*>(elem->channel_data)->transport;
856 return make_call_promise(transport, std::move(call_args), std::move(next));
857 };
858 return {
859 connected_channel_start_transport_stream_op_batch,
860 make_call_promise != nullptr ? make_call_wrapper : nullptr,
861 /* init_call: */ nullptr,
862 connected_channel_start_transport_op,
863 sizeof(call_data),
864 connected_channel_init_call_elem,
865 set_pollset_or_pollset_set,
866 connected_channel_destroy_call_elem,
867 sizeof(channel_data),
868 connected_channel_init_channel_elem,
869 +[](grpc_channel_stack* channel_stack, grpc_channel_element* elem) {
870 // HACK(ctiller): increase call stack size for the channel to make
871 // space for channel data. We need a cleaner (but performant) way to
872 // do this, and I'm not sure what that is yet. This is only "safe"
873 // because call stacks place no additional data after the last call
874 // element, and the last call element MUST be the connected channel.
875 auto* transport =
876 static_cast<channel_data*>(elem->channel_data)->transport;
877 if (transport->filter_stack_transport() != nullptr) {
878 channel_stack->call_stack_size +=
879 transport->filter_stack_transport()->SizeOfStream();
880 }
881 },
882 connected_channel_destroy_channel_elem,
883 connected_channel_get_channel_info,
884 "connected",
885 };
886 }
887
MakeClientTransportCallPromise(Transport * transport,CallArgs call_args,NextPromiseFactory)888 ArenaPromise<ServerMetadataHandle> MakeClientTransportCallPromise(
889 Transport* transport, CallArgs call_args, NextPromiseFactory) {
890 auto spine = GetContext<CallContext>()->MakeCallSpine(std::move(call_args));
891 transport->client_transport()->StartCall(CallHandler{spine});
892 return Map(spine->server_trailing_metadata().receiver.Next(),
893 [](NextResult<ServerMetadataHandle> r) {
894 if (r.has_value()) {
895 auto md = std::move(r.value());
896 md->Set(GrpcStatusFromWire(), true);
897 return md;
898 }
899 auto m = GetContext<Arena>()->MakePooled<ServerMetadata>();
900 m->Set(GrpcStatusMetadata(), GRPC_STATUS_CANCELLED);
901 m->Set(GrpcCallWasCancelled(), true);
902 return m;
903 });
904 }
905
906 const grpc_channel_filter kClientPromiseBasedTransportFilter =
907 MakeConnectedFilter<MakeClientTransportCallPromise>();
908
909 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
910 const grpc_channel_filter kClientEmulatedFilter =
911 MakeConnectedFilter<MakeClientCallPromise>();
912 #else
913 const grpc_channel_filter kClientEmulatedFilter =
914 MakeConnectedFilter<nullptr>();
915 #endif
916
917 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
918 const grpc_channel_filter kServerEmulatedFilter =
919 MakeConnectedFilter<MakeServerCallPromise>();
920 #else
921 const grpc_channel_filter kServerEmulatedFilter =
922 MakeConnectedFilter<nullptr>();
923 #endif
924
// No-op terminal filter for the v3 stack. It is a placeholder that exists only
// because other code requires the stack to have a terminator.
// TODO(ctiller): delete when v3 transition is complete.
//
// Entries are positional and line up slot-for-slot with the vtable built in
// MakeConnectedFilter(). Compared with that vtable, this one has:
//   - no stream-op-batch handler;
//   - a promise factory that crashes if ever invoked;
//   - a no-op init_call;
//   - zero per-call data and no call element init/destroy;
//   - a no-op post-init hook, so the call stack size is not grown.
const grpc_channel_filter kServerPromiseBasedTransportFilter = {
    // No stream-op-batch handler.
    nullptr,
    // Promise based call creation is not implemented here; reaching it
    // crashes.
    [](grpc_channel_element*, CallArgs, NextPromiseFactory)
        -> ArenaPromise<ServerMetadataHandle> { Crash("not implemented"); },
    /* init_call: */ [](grpc_channel_element*, CallSpineInterface*) {},
    connected_channel_start_transport_op,
    // No per-call data.
    0,
    // No call element initializer.
    nullptr,
    set_pollset_or_pollset_set,
    // No call element destructor.
    nullptr,
    sizeof(channel_data),
    connected_channel_init_channel_elem,
    // Post-init hook intentionally does nothing: unlike MakeConnectedFilter()
    // it does not enlarge the call stack.
    +[](grpc_channel_stack*, grpc_channel_element*) {},
    connected_channel_destroy_channel_elem,
    connected_channel_get_channel_info,
    "connected",
};
945
TransportSupportsClientPromiseBasedCalls(const ChannelArgs & args)946 bool TransportSupportsClientPromiseBasedCalls(const ChannelArgs& args) {
947 auto* transport = args.GetObject<Transport>();
948 return transport->client_transport() != nullptr;
949 }
950
TransportSupportsServerPromiseBasedCalls(const ChannelArgs & args)951 bool TransportSupportsServerPromiseBasedCalls(const ChannelArgs& args) {
952 auto* transport = args.GetObject<Transport>();
953 return transport->server_transport() != nullptr;
954 }
955
956 } // namespace
957
RegisterConnectedChannel(CoreConfiguration::Builder * builder)958 void RegisterConnectedChannel(CoreConfiguration::Builder* builder) {
959 // We can't know promise based call or not here (that decision needs the
960 // collaboration of all of the filters on the channel, and we don't want
961 // ordering constraints on when we add filters).
962 // We can know if this results in a promise based call how we'll create
963 // our promise (if indeed we can), and so that is the choice made here.
964
965 // Option 1, and our ideal: the transport supports promise based calls,
966 // and so we simply use the transport directly.
967 builder->channel_init()
968 ->RegisterFilter(GRPC_CLIENT_SUBCHANNEL,
969 &kClientPromiseBasedTransportFilter)
970 .Terminal()
971 .If(TransportSupportsClientPromiseBasedCalls);
972 builder->channel_init()
973 ->RegisterFilter(GRPC_CLIENT_DIRECT_CHANNEL,
974 &kClientPromiseBasedTransportFilter)
975 .Terminal()
976 .If(TransportSupportsClientPromiseBasedCalls);
977 builder->channel_init()
978 ->RegisterFilter(GRPC_SERVER_CHANNEL, &kServerPromiseBasedTransportFilter)
979 .Terminal()
980 .If(TransportSupportsServerPromiseBasedCalls);
981
982 // Option 2: the transport does not support promise based calls.
983 builder->channel_init()
984 ->RegisterFilter(GRPC_CLIENT_SUBCHANNEL, &kClientEmulatedFilter)
985 .Terminal()
986 .IfNot(TransportSupportsClientPromiseBasedCalls);
987 builder->channel_init()
988 ->RegisterFilter(GRPC_CLIENT_DIRECT_CHANNEL, &kClientEmulatedFilter)
989 .Terminal()
990 .IfNot(TransportSupportsClientPromiseBasedCalls);
991 builder->channel_init()
992 ->RegisterFilter(GRPC_SERVER_CHANNEL, &kServerEmulatedFilter)
993 .Terminal()
994 .IfNot(TransportSupportsServerPromiseBasedCalls);
995 }
996
997 } // namespace grpc_core
998