1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H
16 #define GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H
17
18 #include <grpc/byte_buffer.h>
19 #include <grpc/compression.h>
20 #include <grpc/event_engine/event_engine.h>
21 #include <grpc/grpc.h>
22 #include <grpc/impl/call.h>
23 #include <grpc/impl/propagation_bits.h>
24 #include <grpc/slice.h>
25 #include <grpc/slice_buffer.h>
26 #include <grpc/status.h>
27 #include <grpc/support/alloc.h>
28 #include <grpc/support/atm.h>
29 #include <grpc/support/port_platform.h>
30 #include <grpc/support/string_util.h>
31 #include <inttypes.h>
32 #include <limits.h>
33 #include <stdlib.h>
34 #include <string.h>
35
36 #include <algorithm>
37 #include <atomic>
38 #include <cstdint>
39 #include <string>
40 #include <type_traits>
41 #include <utility>
42
43 #include "absl/log/check.h"
44 #include "absl/status/status.h"
45 #include "absl/strings/str_cat.h"
46 #include "absl/strings/string_view.h"
47 #include "src/core/lib/promise/activity.h"
48 #include "src/core/lib/promise/cancel_callback.h"
49 #include "src/core/lib/promise/map.h"
50 #include "src/core/lib/promise/poll.h"
51 #include "src/core/lib/promise/seq.h"
52 #include "src/core/lib/promise/status_flag.h"
53 #include "src/core/lib/surface/completion_queue.h"
54 #include "src/core/lib/transport/message.h"
55 #include "src/core/lib/transport/metadata.h"
56 #include "src/core/lib/transport/metadata_batch.h"
57 #include "src/core/util/crash.h"
58
59 namespace grpc_core {
60
61 class PublishToAppEncoder {
62 public:
PublishToAppEncoder(grpc_metadata_array * dest,const grpc_metadata_batch * encoding,bool is_client)63 explicit PublishToAppEncoder(grpc_metadata_array* dest,
64 const grpc_metadata_batch* encoding,
65 bool is_client)
66 : dest_(dest), encoding_(encoding), is_client_(is_client) {}
67
Encode(const Slice & key,const Slice & value)68 void Encode(const Slice& key, const Slice& value) {
69 Append(key.c_slice(), value.c_slice());
70 }
71
72 // Catch anything that is not explicitly handled, and do not publish it to the
73 // application. If new metadata is added to a batch that needs to be
74 // published, it should be called out here.
75 template <typename Which>
Encode(Which,const typename Which::ValueType &)76 void Encode(Which, const typename Which::ValueType&) {}
77
Encode(UserAgentMetadata,const Slice & slice)78 void Encode(UserAgentMetadata, const Slice& slice) {
79 Append(UserAgentMetadata::key(), slice);
80 }
81
Encode(HostMetadata,const Slice & slice)82 void Encode(HostMetadata, const Slice& slice) {
83 Append(HostMetadata::key(), slice);
84 }
85
Encode(GrpcPreviousRpcAttemptsMetadata,uint32_t count)86 void Encode(GrpcPreviousRpcAttemptsMetadata, uint32_t count) {
87 Append(GrpcPreviousRpcAttemptsMetadata::key(), count);
88 }
89
Encode(GrpcRetryPushbackMsMetadata,Duration count)90 void Encode(GrpcRetryPushbackMsMetadata, Duration count) {
91 Append(GrpcRetryPushbackMsMetadata::key(), count.millis());
92 }
93
Encode(LbTokenMetadata,const Slice & slice)94 void Encode(LbTokenMetadata, const Slice& slice) {
95 Append(LbTokenMetadata::key(), slice);
96 }
97
Encode(W3CTraceParentMetadata,const Slice & slice)98 void Encode(W3CTraceParentMetadata, const Slice& slice) {
99 Append(W3CTraceParentMetadata::key(), slice);
100 }
101
102 private:
Append(absl::string_view key,int64_t value)103 void Append(absl::string_view key, int64_t value) {
104 Append(StaticSlice::FromStaticString(key).c_slice(),
105 Slice::FromInt64(value).c_slice());
106 }
107
Append(absl::string_view key,const Slice & value)108 void Append(absl::string_view key, const Slice& value) {
109 Append(StaticSlice::FromStaticString(key).c_slice(), value.c_slice());
110 }
111
Append(grpc_slice key,grpc_slice value)112 void Append(grpc_slice key, grpc_slice value) {
113 if (dest_->count == dest_->capacity) {
114 Crash(absl::StrCat(
115 "Too many metadata entries: capacity=", dest_->capacity, " on ",
116 is_client_ ? "client" : "server", " encoding ", encoding_->count(),
117 " elements: ", encoding_->DebugString().c_str()));
118 }
119 auto* mdusr = &dest_->metadata[dest_->count++];
120 mdusr->key = key;
121 mdusr->value = value;
122 }
123
124 grpc_metadata_array* const dest_;
125 const grpc_metadata_batch* const encoding_;
126 const bool is_client_;
127 };
128
129 void PublishMetadataArray(grpc_metadata_batch* md, grpc_metadata_array* array,
130 bool is_client);
131 void CToMetadata(grpc_metadata* metadata, size_t count, grpc_metadata_batch* b);
132 const char* GrpcOpTypeName(grpc_op_type op);
133
134 bool ValidateMetadata(size_t count, grpc_metadata* metadata);
135 void EndOpImmediately(grpc_completion_queue* cq, void* notify_tag,
136 bool is_notify_tag_closure);
137
AreWriteFlagsValid(uint32_t flags)138 inline bool AreWriteFlagsValid(uint32_t flags) {
139 // check that only bits in GRPC_WRITE_(INTERNAL?)_USED_MASK are set
140 const uint32_t allowed_write_positions =
141 (GRPC_WRITE_USED_MASK | GRPC_WRITE_INTERNAL_USED_MASK);
142 const uint32_t invalid_positions = ~allowed_write_positions;
143 return !(flags & invalid_positions);
144 }
145
AreInitialMetadataFlagsValid(uint32_t flags)146 inline bool AreInitialMetadataFlagsValid(uint32_t flags) {
147 // check that only bits in GRPC_WRITE_(INTERNAL?)_USED_MASK are set
148 uint32_t invalid_positions = ~GRPC_INITIAL_METADATA_USED_MASK;
149 return !(flags & invalid_positions);
150 }
151
152 // One batch operation
153 // Wrapper around promise steps to perform once of the batch operations for the
154 // legacy grpc surface api.
155 template <typename SetupResult, grpc_op_type kOp>
156 class OpHandlerImpl {
157 public:
158 using PromiseFactory = promise_detail::OncePromiseFactory<void, SetupResult>;
159 using Promise = typename PromiseFactory::Promise;
160 static_assert(!std::is_same<Promise, void>::value,
161 "PromiseFactory must return a promise");
162
OpHandlerImpl()163 OpHandlerImpl() : state_(State::kDismissed) {}
OpHandlerImpl(SetupResult result)164 explicit OpHandlerImpl(SetupResult result) : state_(State::kPromiseFactory) {
165 Construct(&promise_factory_, std::move(result));
166 }
167
~OpHandlerImpl()168 ~OpHandlerImpl() {
169 switch (state_) {
170 case State::kDismissed:
171 break;
172 case State::kPromiseFactory:
173 Destruct(&promise_factory_);
174 break;
175 case State::kPromise:
176 Destruct(&promise_);
177 break;
178 }
179 }
180
181 OpHandlerImpl(const OpHandlerImpl&) = delete;
182 OpHandlerImpl& operator=(const OpHandlerImpl&) = delete;
OpHandlerImpl(OpHandlerImpl && other)183 OpHandlerImpl(OpHandlerImpl&& other) noexcept : state_(other.state_) {
184 switch (state_) {
185 case State::kDismissed:
186 break;
187 case State::kPromiseFactory:
188 Construct(&promise_factory_, std::move(other.promise_factory_));
189 break;
190 case State::kPromise:
191 Construct(&promise_, std::move(other.promise_));
192 break;
193 }
194 }
195 OpHandlerImpl& operator=(OpHandlerImpl&& other) noexcept = delete;
196
operator()197 Poll<StatusFlag> operator()() {
198 switch (state_) {
199 case State::kDismissed:
200 return Success{};
201 case State::kPromiseFactory: {
202 auto promise = promise_factory_.Make();
203 Destruct(&promise_factory_);
204 Construct(&promise_, std::move(promise));
205 state_ = State::kPromise;
206 }
207 ABSL_FALLTHROUGH_INTENDED;
208 case State::kPromise: {
209 GRPC_TRACE_LOG(call, INFO)
210 << Activity::current()->DebugTag() << "BeginPoll " << OpName();
211 auto r = poll_cast<StatusFlag>(promise_());
212 GRPC_TRACE_LOG(call, INFO)
213 << Activity::current()->DebugTag() << "EndPoll " << OpName()
214 << " --> "
215 << (r.pending() ? "PENDING" : (r.value().ok() ? "OK" : "FAILURE"));
216 return r;
217 }
218 }
219 GPR_UNREACHABLE_CODE(return Pending{});
220 }
221
222 private:
223 enum class State {
224 kDismissed,
225 kPromiseFactory,
226 kPromise,
227 };
228
OpName()229 static const char* OpName() { return GrpcOpTypeName(kOp); }
230
231 // gcc-12 has problems with this being a variant
232 GPR_NO_UNIQUE_ADDRESS State state_;
233 union {
234 PromiseFactory promise_factory_;
235 Promise promise_;
236 };
237 };
238
239 template <grpc_op_type op_type, typename PromiseFactory>
OpHandler(PromiseFactory setup)240 auto OpHandler(PromiseFactory setup) {
241 return OpHandlerImpl<PromiseFactory, op_type>(std::move(setup));
242 }
243
244 class BatchOpIndex {
245 public:
BatchOpIndex(const grpc_op * ops,size_t nops)246 BatchOpIndex(const grpc_op* ops, size_t nops) : ops_(ops) {
247 for (size_t i = 0; i < nops; i++) {
248 idxs_[ops[i].op] = static_cast<uint8_t>(i);
249 }
250 }
251
252 // 1. Check if op_type is in the batch
253 // 2. If it is, run the setup function in the context of the API call (NOT in
254 // the call party).
255 // 3. This setup function returns a promise factory which we'll then run *in*
256 // the party to do initial setup, and have it return the promise that we'll
257 // ultimately poll on til completion.
258 // Once we express our surface API in terms of core internal types this whole
259 // dance will go away.
260 template <grpc_op_type op_type, typename SetupFn>
OpHandler(SetupFn setup)261 auto OpHandler(SetupFn setup) {
262 using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
263 using Impl = OpHandlerImpl<SetupResult, op_type>;
264 if (const grpc_op* op = this->op(op_type)) {
265 auto r = setup(*op);
266 return Impl(std::move(r));
267 } else {
268 return Impl();
269 }
270 }
271
op(grpc_op_type op_type)272 const grpc_op* op(grpc_op_type op_type) const {
273 return idxs_[op_type] == 255 ? nullptr : &ops_[idxs_[op_type]];
274 }
275
has_op(grpc_op_type op_type)276 bool has_op(grpc_op_type op_type) const { return idxs_[op_type] != 255; }
277
278 private:
279 const grpc_op* const ops_;
280 std::array<uint8_t, 8> idxs_{255, 255, 255, 255, 255, 255, 255, 255};
281 };
282
283 // Defines a promise that calls grpc_cq_end_op() (on first poll) and then waits
284 // for the callback supplied to grpc_cq_end_op() to be called, before resolving
285 // to Empty{}
286 class WaitForCqEndOp {
287 public:
WaitForCqEndOp(bool is_closure,void * tag,grpc_error_handle error,grpc_completion_queue * cq)288 WaitForCqEndOp(bool is_closure, void* tag, grpc_error_handle error,
289 grpc_completion_queue* cq)
290 : state_{NotStarted{is_closure, tag, std::move(error), cq}} {}
291
292 Poll<Empty> operator()();
293
294 WaitForCqEndOp(const WaitForCqEndOp&) = delete;
295 WaitForCqEndOp& operator=(const WaitForCqEndOp&) = delete;
WaitForCqEndOp(WaitForCqEndOp && other)296 WaitForCqEndOp(WaitForCqEndOp&& other) noexcept
297 : state_(std::move(absl::get<NotStarted>(other.state_))) {
298 other.state_.emplace<Invalid>();
299 }
300 WaitForCqEndOp& operator=(WaitForCqEndOp&& other) noexcept {
301 state_ = std::move(absl::get<NotStarted>(other.state_));
302 other.state_.emplace<Invalid>();
303 return *this;
304 }
305
306 private:
307 struct NotStarted {
308 bool is_closure;
309 void* tag;
310 grpc_error_handle error;
311 grpc_completion_queue* cq;
312 };
313 struct Started {
StartedStarted314 explicit Started(Waker waker) : waker(std::move(waker)) {}
315 Waker waker;
316 grpc_cq_completion completion;
317 std::atomic<bool> done{false};
318 };
319 struct Invalid {};
320 using State = absl::variant<NotStarted, Started, Invalid>;
321
322 static std::string StateString(const State& state);
323
324 State state_{Invalid{}};
325 };
326
327 template <typename FalliblePart, typename FinalPart>
InfallibleBatch(FalliblePart fallible_part,FinalPart final_part,bool is_notify_tag_closure,void * notify_tag,grpc_completion_queue * cq)328 auto InfallibleBatch(FalliblePart fallible_part, FinalPart final_part,
329 bool is_notify_tag_closure, void* notify_tag,
330 grpc_completion_queue* cq) {
331 // Perform fallible_part, then final_part, then wait for the
332 // completion queue to be done.
333 // If cancelled, we'll ensure the completion queue is notified.
334 // There's a slight bug here in that if we cancel this promise after
335 // the WaitForCqEndOp we'll double post -- but we don't currently do that.
336 return OnCancelFactory(
337 [fallible_part = std::move(fallible_part),
338 final_part = std::move(final_part), is_notify_tag_closure, notify_tag,
339 cq]() mutable {
340 return LogPollBatch(notify_tag,
341 Seq(std::move(fallible_part), std::move(final_part),
342 [is_notify_tag_closure, notify_tag, cq]() {
343 return WaitForCqEndOp(is_notify_tag_closure,
344 notify_tag,
345 absl::OkStatus(), cq);
346 }));
347 },
348 [cq, notify_tag]() {
349 grpc_cq_end_op(
350 cq, notify_tag, absl::OkStatus(),
351 [](void*, grpc_cq_completion* completion) { delete completion; },
352 nullptr, new grpc_cq_completion);
353 });
354 }
355
356 template <typename FalliblePart>
FallibleBatch(FalliblePart fallible_part,bool is_notify_tag_closure,void * notify_tag,grpc_completion_queue * cq)357 auto FallibleBatch(FalliblePart fallible_part, bool is_notify_tag_closure,
358 void* notify_tag, grpc_completion_queue* cq) {
359 // Perform fallible_part, then wait for the completion queue to be done.
360 // If cancelled, we'll ensure the completion queue is notified.
361 // There's a slight bug here in that if we cancel this promise after
362 // the WaitForCqEndOp we'll double post -- but we don't currently do that.
363 return OnCancelFactory(
364 [fallible_part = std::move(fallible_part), is_notify_tag_closure,
365 notify_tag, cq]() mutable {
366 return LogPollBatch(
367 notify_tag,
368 Seq(std::move(fallible_part),
369 [is_notify_tag_closure, notify_tag, cq](StatusFlag r) {
370 return WaitForCqEndOp(is_notify_tag_closure, notify_tag,
371 StatusCast<absl::Status>(r), cq);
372 }));
373 },
374 [cq]() {
375 grpc_cq_end_op(
376 cq, nullptr, absl::CancelledError(),
377 [](void*, grpc_cq_completion* completion) { delete completion; },
378 nullptr, new grpc_cq_completion);
379 });
380 }
381
382 template <typename F>
383 class PollBatchLogger {
384 public:
PollBatchLogger(void * tag,F f)385 PollBatchLogger(void* tag, F f) : tag_(tag), f_(std::move(f)) {}
386
operator()387 auto operator()() {
388 GRPC_TRACE_LOG(call, INFO) << "Poll batch " << tag_;
389 auto r = f_();
390 GRPC_TRACE_LOG(call, INFO)
391 << "Poll batch " << tag_ << " --> " << ResultString(r);
392 return r;
393 }
394
395 private:
396 template <typename T>
ResultString(Poll<T> r)397 static std::string ResultString(Poll<T> r) {
398 if (r.pending()) return "PENDING";
399 return ResultString(r.value());
400 }
ResultString(Empty)401 static std::string ResultString(Empty) { return "DONE"; }
402
403 void* tag_;
404 F f_;
405 };
406
407 template <typename F>
LogPollBatch(void * tag,F f)408 PollBatchLogger<F> LogPollBatch(void* tag, F f) {
409 return PollBatchLogger<F>(tag, std::move(f));
410 }
411
412 class MessageReceiver {
413 public:
incoming_compression_algorithm()414 grpc_compression_algorithm incoming_compression_algorithm() const {
415 return incoming_compression_algorithm_;
416 }
417
SetIncomingCompressionAlgorithm(grpc_compression_algorithm incoming_compression_algorithm)418 void SetIncomingCompressionAlgorithm(
419 grpc_compression_algorithm incoming_compression_algorithm) {
420 incoming_compression_algorithm_ = incoming_compression_algorithm;
421 }
422
last_message_flags()423 uint32_t last_message_flags() const { return test_only_last_message_flags_; }
424
425 template <typename Puller>
MakeBatchOp(const grpc_op & op,Puller * puller)426 auto MakeBatchOp(const grpc_op& op, Puller* puller) {
427 CHECK_EQ(recv_message_, nullptr);
428 recv_message_ = op.data.recv_message.recv_message;
429 return [this, puller]() mutable {
430 return Map(puller->PullMessage(),
431 [this](typename Puller::NextMessage msg) {
432 return FinishRecvMessage(std::move(msg));
433 });
434 };
435 }
436
437 private:
438 template <typename NextMessage>
FinishRecvMessage(NextMessage result)439 StatusFlag FinishRecvMessage(NextMessage result) {
440 if (!result.ok()) {
441 GRPC_TRACE_LOG(call, INFO)
442 << Activity::current()->DebugTag()
443 << "[call] RecvMessage: outstanding_recv "
444 "finishes: received end-of-stream with error";
445 *recv_message_ = nullptr;
446 recv_message_ = nullptr;
447 return Failure{};
448 }
449 if (!result.has_value()) {
450 GRPC_TRACE_LOG(call, INFO) << Activity::current()->DebugTag()
451 << "[call] RecvMessage: outstanding_recv "
452 "finishes: received end-of-stream";
453 *recv_message_ = nullptr;
454 recv_message_ = nullptr;
455 return Success{};
456 }
457 MessageHandle message = result.TakeValue();
458 test_only_last_message_flags_ = message->flags();
459 if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) &&
460 (incoming_compression_algorithm_ != GRPC_COMPRESS_NONE)) {
461 *recv_message_ = grpc_raw_compressed_byte_buffer_create(
462 nullptr, 0, incoming_compression_algorithm_);
463 } else {
464 *recv_message_ = grpc_raw_byte_buffer_create(nullptr, 0);
465 }
466 grpc_slice_buffer_move_into(message->payload()->c_slice_buffer(),
467 &(*recv_message_)->data.raw.slice_buffer);
468 GRPC_TRACE_LOG(call, INFO)
469 << Activity::current()->DebugTag()
470 << "[call] RecvMessage: outstanding_recv "
471 "finishes: received "
472 << (*recv_message_)->data.raw.slice_buffer.length << " byte message";
473 recv_message_ = nullptr;
474 return Success{};
475 }
476
477 grpc_byte_buffer** recv_message_ = nullptr;
478 uint32_t test_only_last_message_flags_ = 0;
479 // Compression algorithm for incoming data
480 grpc_compression_algorithm incoming_compression_algorithm_ =
481 GRPC_COMPRESS_NONE;
482 };
483
484 std::string MakeErrorString(const ServerMetadata* trailing_metadata);
485
486 } // namespace grpc_core
487
488 #endif // GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H
489