// Copyright 2024 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H
#define GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H

#include <grpc/byte_buffer.h>
#include <grpc/compression.h>
#include <grpc/event_engine/event_engine.h>
#include <grpc/grpc.h>
#include <grpc/impl/call.h>
#include <grpc/impl/propagation_bits.h>
#include <grpc/slice.h>
#include <grpc/slice_buffer.h>
#include <grpc/status.h>
#include <grpc/support/alloc.h>
#include <grpc/support/atm.h>
#include <grpc/support/port_platform.h>
#include <grpc/support/string_util.h>
#include <inttypes.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/cancel_callback.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/promise/status_flag.h"
#include "src/core/lib/surface/completion_queue.h"
#include "src/core/lib/transport/message.h"
#include "src/core/lib/transport/metadata.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/util/crash.h"

namespace grpc_core {

class PublishToAppEncoder {
 public:
  explicit PublishToAppEncoder(grpc_metadata_array* dest,
                               const grpc_metadata_batch* encoding,
                               bool is_client)
      : dest_(dest), encoding_(encoding), is_client_(is_client) {}

  void Encode(const Slice& key, const Slice& value) {
    Append(key.c_slice(), value.c_slice());
  }

  // Catch anything that is not explicitly handled, and do not publish it to
  // the application. If new metadata is added to a batch that needs to be
  // published, it should be called out here.
  template <typename Which>
  void Encode(Which, const typename Which::ValueType&) {}

  void Encode(UserAgentMetadata, const Slice& slice) {
    Append(UserAgentMetadata::key(), slice);
  }

  void Encode(HostMetadata, const Slice& slice) {
    Append(HostMetadata::key(), slice);
  }

  void Encode(GrpcPreviousRpcAttemptsMetadata, uint32_t count) {
    Append(GrpcPreviousRpcAttemptsMetadata::key(), count);
  }

  void Encode(GrpcRetryPushbackMsMetadata, Duration count) {
    Append(GrpcRetryPushbackMsMetadata::key(), count.millis());
  }

  void Encode(LbTokenMetadata, const Slice& slice) {
    Append(LbTokenMetadata::key(), slice);
  }

  void Encode(W3CTraceParentMetadata, const Slice& slice) {
    Append(W3CTraceParentMetadata::key(), slice);
  }

 private:
  void Append(absl::string_view key, int64_t value) {
    Append(StaticSlice::FromStaticString(key).c_slice(),
           Slice::FromInt64(value).c_slice());
  }

  void Append(absl::string_view key, const Slice& value) {
    Append(StaticSlice::FromStaticString(key).c_slice(), value.c_slice());
  }

  void Append(grpc_slice key, grpc_slice value) {
    if (dest_->count == dest_->capacity) {
      Crash(absl::StrCat(
          "Too many metadata entries: capacity=", dest_->capacity, " on ",
          is_client_ ? "client" : "server", " encoding ", encoding_->count(),
          " elements: ", encoding_->DebugString().c_str()));
    }
    auto* mdusr = &dest_->metadata[dest_->count++];
    mdusr->key = key;
    mdusr->value = value;
  }

  grpc_metadata_array* const dest_;
  const grpc_metadata_batch* const encoding_;
  const bool is_client_;
};
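
// Illustrative sketch (not part of this header): PublishToAppEncoder is a
// metadata-batch visitor, so callers drive it through
// grpc_metadata_batch::Encode. PublishMetadataArray (declared below) does
// something along these lines; the capacity-growth policy shown here is an
// assumption for illustration only.
//
//   void PublishExample(grpc_metadata_batch* md, grpc_metadata_array* array,
//                       bool is_client) {
//     if (md->count() > array->capacity) {
//       array->capacity = md->count();  // hypothetical growth policy
//       array->metadata = static_cast<grpc_metadata*>(gpr_realloc(
//           array->metadata, sizeof(grpc_metadata) * array->capacity));
//     }
//     PublishToAppEncoder encoder(array, md, is_client);
//     md->Encode(&encoder);  // visits each element, filling `array`
//   }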

void PublishMetadataArray(grpc_metadata_batch* md, grpc_metadata_array* array,
                          bool is_client);
void CToMetadata(grpc_metadata* metadata, size_t count, grpc_metadata_batch* b);
const char* GrpcOpTypeName(grpc_op_type op);

bool ValidateMetadata(size_t count, grpc_metadata* metadata);
void EndOpImmediately(grpc_completion_queue* cq, void* notify_tag,
                      bool is_notify_tag_closure);

inline bool AreWriteFlagsValid(uint32_t flags) {
  // check that only bits in GRPC_WRITE_(INTERNAL?)_USED_MASK are set
  const uint32_t allowed_write_positions =
      (GRPC_WRITE_USED_MASK | GRPC_WRITE_INTERNAL_USED_MASK);
  const uint32_t invalid_positions = ~allowed_write_positions;
  return !(flags & invalid_positions);
}

inline bool AreInitialMetadataFlagsValid(uint32_t flags) {
  // check that only bits in GRPC_INITIAL_METADATA_USED_MASK are set
  uint32_t invalid_positions = ~GRPC_INITIAL_METADATA_USED_MASK;
  return !(flags & invalid_positions);
}
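
// For reference, a hedged sketch of how a batch-validation loop might use the
// two helpers above (the real call sites live in the call implementation):
//
//   for (size_t i = 0; i < nops; ++i) {
//     const grpc_op& op = ops[i];
//     if (op.op == GRPC_OP_SEND_MESSAGE && !AreWriteFlagsValid(op.flags)) {
//       return GRPC_CALL_ERROR_INVALID_FLAGS;
//     }
//     if (op.op == GRPC_OP_SEND_INITIAL_METADATA &&
//         !AreInitialMetadataFlagsValid(op.flags)) {
//       return GRPC_CALL_ERROR_INVALID_FLAGS;
//     }
//   }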

// One batch operation
// Wrapper around promise steps to perform one of the batch operations for the
// legacy grpc surface API.
template <typename SetupResult, grpc_op_type kOp>
class OpHandlerImpl {
 public:
  using PromiseFactory = promise_detail::OncePromiseFactory<void, SetupResult>;
  using Promise = typename PromiseFactory::Promise;
  static_assert(!std::is_same<Promise, void>::value,
                "PromiseFactory must return a promise");

  OpHandlerImpl() : state_(State::kDismissed) {}
  explicit OpHandlerImpl(SetupResult result) : state_(State::kPromiseFactory) {
    Construct(&promise_factory_, std::move(result));
  }

  ~OpHandlerImpl() {
    switch (state_) {
      case State::kDismissed:
        break;
      case State::kPromiseFactory:
        Destruct(&promise_factory_);
        break;
      case State::kPromise:
        Destruct(&promise_);
        break;
    }
  }

  OpHandlerImpl(const OpHandlerImpl&) = delete;
  OpHandlerImpl& operator=(const OpHandlerImpl&) = delete;
  OpHandlerImpl(OpHandlerImpl&& other) noexcept : state_(other.state_) {
    switch (state_) {
      case State::kDismissed:
        break;
      case State::kPromiseFactory:
        Construct(&promise_factory_, std::move(other.promise_factory_));
        break;
      case State::kPromise:
        Construct(&promise_, std::move(other.promise_));
        break;
    }
  }
  OpHandlerImpl& operator=(OpHandlerImpl&& other) noexcept = delete;

  Poll<StatusFlag> operator()() {
    switch (state_) {
      case State::kDismissed:
        return Success{};
      case State::kPromiseFactory: {
        auto promise = promise_factory_.Make();
        Destruct(&promise_factory_);
        Construct(&promise_, std::move(promise));
        state_ = State::kPromise;
      }
        ABSL_FALLTHROUGH_INTENDED;
      case State::kPromise: {
        GRPC_TRACE_LOG(call, INFO)
            << Activity::current()->DebugTag() << "BeginPoll " << OpName();
        auto r = poll_cast<StatusFlag>(promise_());
        GRPC_TRACE_LOG(call, INFO)
            << Activity::current()->DebugTag() << "EndPoll " << OpName()
            << " --> "
            << (r.pending() ? "PENDING" : (r.value().ok() ? "OK" : "FAILURE"));
        return r;
      }
    }
    GPR_UNREACHABLE_CODE(return Pending{});
  }

 private:
  enum class State {
    kDismissed,
    kPromiseFactory,
    kPromise,
  };

  static const char* OpName() { return GrpcOpTypeName(kOp); }

  // gcc-12 has problems with this being a variant
  GPR_NO_UNIQUE_ADDRESS State state_;
  union {
    PromiseFactory promise_factory_;
    Promise promise_;
  };
};

template <grpc_op_type op_type, typename PromiseFactory>
auto OpHandler(PromiseFactory setup) {
  return OpHandlerImpl<PromiseFactory, op_type>(std::move(setup));
}
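
// Hedged usage sketch: OpHandler ties a promise factory to the op it serves so
// poll-time tracing can name it. The message-sending body below is
// hypothetical; any callable returning a StatusFlag-yielding promise works.
//
//   auto send_message = OpHandler<GRPC_OP_SEND_MESSAGE>(
//       [msg = std::move(msg), initiator]() mutable {
//         return initiator.PushMessage(std::move(msg));  // hypothetical sender
//       });
//   // `send_message` is itself a promise: poll it until it resolves to
//   // Success{} or Failure{}.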

class BatchOpIndex {
 public:
  BatchOpIndex(const grpc_op* ops, size_t nops) : ops_(ops) {
    for (size_t i = 0; i < nops; i++) {
      idxs_[ops[i].op] = static_cast<uint8_t>(i);
    }
  }

  // 1. Check if op_type is in the batch
  // 2. If it is, run the setup function in the context of the API call (NOT in
  //    the call party).
  // 3. This setup function returns a promise factory which we'll then run *in*
  //    the party to do initial setup, and have it return the promise that
  //    we'll ultimately poll on until completion.
  // Once we express our surface API in terms of core internal types this whole
  // dance will go away.
  template <grpc_op_type op_type, typename SetupFn>
  auto OpHandler(SetupFn setup) {
    using SetupResult = decltype(std::declval<SetupFn>()(grpc_op()));
    using Impl = OpHandlerImpl<SetupResult, op_type>;
    if (const grpc_op* op = this->op(op_type)) {
      auto r = setup(*op);
      return Impl(std::move(r));
    } else {
      return Impl();
    }
  }

  const grpc_op* op(grpc_op_type op_type) const {
    return idxs_[op_type] == 255 ? nullptr : &ops_[idxs_[op_type]];
  }

  bool has_op(grpc_op_type op_type) const { return idxs_[op_type] != 255; }

 private:
  const grpc_op* const ops_;
  std::array<uint8_t, 8> idxs_{255, 255, 255, 255, 255, 255, 255, 255};
};
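
// Hedged sketch of the dance described above (`call_initiator` and the
// metadata-publishing step are placeholders, not part of this header):
//
//   BatchOpIndex idx(ops, nops);
//   auto recv_im = idx.OpHandler<GRPC_OP_RECV_INITIAL_METADATA>(
//       // Step 2: runs immediately on the API thread, capturing app pointers.
//       [&](const grpc_op& op) {
//         auto* out = op.data.recv_initial_metadata.recv_initial_metadata;
//         // Step 3: return a promise factory to be run on the call party.
//         return [out, call_initiator]() mutable {
//           return Map(call_initiator.PullServerInitialMetadata(),
//                      [out](auto md) { /* publish md into *out */
//                                       return Success{}; });
//         };
//       });
//   // If GRPC_OP_RECV_INITIAL_METADATA was absent from `ops`, recv_im is a
//   // dismissed handler that resolves immediately to Success{}.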

// Defines a promise that calls grpc_cq_end_op() (on first poll) and then waits
// for the callback supplied to grpc_cq_end_op() to be called, before resolving
// to Empty{}.
class WaitForCqEndOp {
 public:
  WaitForCqEndOp(bool is_closure, void* tag, grpc_error_handle error,
                 grpc_completion_queue* cq)
      : state_{NotStarted{is_closure, tag, std::move(error), cq}} {}

  Poll<Empty> operator()();

  WaitForCqEndOp(const WaitForCqEndOp&) = delete;
  WaitForCqEndOp& operator=(const WaitForCqEndOp&) = delete;
  WaitForCqEndOp(WaitForCqEndOp&& other) noexcept
      : state_(std::move(absl::get<NotStarted>(other.state_))) {
    other.state_.emplace<Invalid>();
  }
  WaitForCqEndOp& operator=(WaitForCqEndOp&& other) noexcept {
    state_ = std::move(absl::get<NotStarted>(other.state_));
    other.state_.emplace<Invalid>();
    return *this;
  }

 private:
  struct NotStarted {
    bool is_closure;
    void* tag;
    grpc_error_handle error;
    grpc_completion_queue* cq;
  };
  struct Started {
    explicit Started(Waker waker) : waker(std::move(waker)) {}
    Waker waker;
    grpc_cq_completion completion;
    std::atomic<bool> done{false};
  };
  struct Invalid {};
  using State = absl::variant<NotStarted, Started, Invalid>;

  static std::string StateString(const State& state);

  State state_{Invalid{}};
};
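
// Hedged usage sketch: WaitForCqEndOp is typically the terminal step of a
// batch promise (see InfallibleBatch/FallibleBatch below for the real
// composition):
//
//   auto batch = Seq(std::move(send_ops), [tag, cq]() {
//     return WaitForCqEndOp(/*is_closure=*/false, tag, absl::OkStatus(), cq);
//   });
//   // Polling `batch` posts `tag` to `cq` exactly once, and the promise only
//   // resolves to Empty{} after the completion queue has consumed the event.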

template <typename FalliblePart, typename FinalPart>
auto InfallibleBatch(FalliblePart fallible_part, FinalPart final_part,
                     bool is_notify_tag_closure, void* notify_tag,
                     grpc_completion_queue* cq) {
  // Perform fallible_part, then final_part, then wait for the
  // completion queue to be done.
  // If cancelled, we'll ensure the completion queue is notified.
  // There's a slight bug here in that if we cancel this promise after
  // the WaitForCqEndOp we'll double post -- but we don't currently do that.
  return OnCancelFactory(
      [fallible_part = std::move(fallible_part),
       final_part = std::move(final_part), is_notify_tag_closure, notify_tag,
       cq]() mutable {
        return LogPollBatch(notify_tag,
                            Seq(std::move(fallible_part), std::move(final_part),
                                [is_notify_tag_closure, notify_tag, cq]() {
                                  return WaitForCqEndOp(is_notify_tag_closure,
                                                        notify_tag,
                                                        absl::OkStatus(), cq);
                                }));
      },
      [cq, notify_tag]() {
        grpc_cq_end_op(
            cq, notify_tag, absl::OkStatus(),
            [](void*, grpc_cq_completion* completion) { delete completion; },
            nullptr, new grpc_cq_completion);
      });
}

template <typename FalliblePart>
auto FallibleBatch(FalliblePart fallible_part, bool is_notify_tag_closure,
                   void* notify_tag, grpc_completion_queue* cq) {
  // Perform fallible_part, then wait for the completion queue to be done.
  // If cancelled, we'll ensure the completion queue is notified.
  // There's a slight bug here in that if we cancel this promise after
  // the WaitForCqEndOp we'll double post -- but we don't currently do that.
  return OnCancelFactory(
      [fallible_part = std::move(fallible_part), is_notify_tag_closure,
       notify_tag, cq]() mutable {
        return LogPollBatch(
            notify_tag,
            Seq(std::move(fallible_part),
                [is_notify_tag_closure, notify_tag, cq](StatusFlag r) {
                  return WaitForCqEndOp(is_notify_tag_closure, notify_tag,
                                        StatusCast<absl::Status>(r), cq);
                }));
      },
      [cq]() {
        grpc_cq_end_op(
            cq, nullptr, absl::CancelledError(),
            [](void*, grpc_cq_completion* completion) { delete completion; },
            nullptr, new grpc_cq_completion);
      });
}
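
// Hedged sketch of assembling and running one of these batches. `party`,
// `MakeSendMessageFactory`, and the surrounding names are placeholders; the
// assumption is a promise runner with a Spawn(name, promise_factory, on_done)
// method, as grpc_core::Party provides.
//
//   BatchOpIndex idx(ops, nops);
//   auto fallible = idx.OpHandler<GRPC_OP_SEND_MESSAGE>(
//       [&](const grpc_op& op) { return MakeSendMessageFactory(op); });
//   party->Spawn(
//       "batch",
//       FallibleBatch(std::move(fallible), is_notify_tag_closure, notify_tag,
//                     cq),
//       [](Empty) {});
//   // On success or failure, notify_tag is posted to cq exactly once; if the
//   // batch is cancelled first, the OnCancelFactory arm notifies the
//   // completion queue instead.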

template <typename F>
class PollBatchLogger {
 public:
  PollBatchLogger(void* tag, F f) : tag_(tag), f_(std::move(f)) {}

  auto operator()() {
    GRPC_TRACE_LOG(call, INFO) << "Poll batch " << tag_;
    auto r = f_();
    GRPC_TRACE_LOG(call, INFO)
        << "Poll batch " << tag_ << " --> " << ResultString(r);
    return r;
  }

 private:
  template <typename T>
  static std::string ResultString(Poll<T> r) {
    if (r.pending()) return "PENDING";
    return ResultString(r.value());
  }
  static std::string ResultString(Empty) { return "DONE"; }

  void* tag_;
  F f_;
};

template <typename F>
PollBatchLogger<F> LogPollBatch(void* tag, F f) {
  return PollBatchLogger<F>(tag, std::move(f));
}

class MessageReceiver {
 public:
  grpc_compression_algorithm incoming_compression_algorithm() const {
    return incoming_compression_algorithm_;
  }

  void SetIncomingCompressionAlgorithm(
      grpc_compression_algorithm incoming_compression_algorithm) {
    incoming_compression_algorithm_ = incoming_compression_algorithm;
  }

  uint32_t last_message_flags() const { return test_only_last_message_flags_; }

  template <typename Puller>
  auto MakeBatchOp(const grpc_op& op, Puller* puller) {
    CHECK_EQ(recv_message_, nullptr);
    recv_message_ = op.data.recv_message.recv_message;
    return [this, puller]() mutable {
      return Map(puller->PullMessage(),
                 [this](typename Puller::NextMessage msg) {
                   return FinishRecvMessage(std::move(msg));
                 });
    };
  }

 private:
  template <typename NextMessage>
  StatusFlag FinishRecvMessage(NextMessage result) {
    if (!result.ok()) {
      GRPC_TRACE_LOG(call, INFO)
          << Activity::current()->DebugTag()
          << "[call] RecvMessage: outstanding_recv "
             "finishes: received end-of-stream with error";
      *recv_message_ = nullptr;
      recv_message_ = nullptr;
      return Failure{};
    }
    if (!result.has_value()) {
      GRPC_TRACE_LOG(call, INFO) << Activity::current()->DebugTag()
                                 << "[call] RecvMessage: outstanding_recv "
                                    "finishes: received end-of-stream";
      *recv_message_ = nullptr;
      recv_message_ = nullptr;
      return Success{};
    }
    MessageHandle message = result.TakeValue();
    test_only_last_message_flags_ = message->flags();
    if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) &&
        (incoming_compression_algorithm_ != GRPC_COMPRESS_NONE)) {
      *recv_message_ = grpc_raw_compressed_byte_buffer_create(
          nullptr, 0, incoming_compression_algorithm_);
    } else {
      *recv_message_ = grpc_raw_byte_buffer_create(nullptr, 0);
    }
    grpc_slice_buffer_move_into(message->payload()->c_slice_buffer(),
                                &(*recv_message_)->data.raw.slice_buffer);
    GRPC_TRACE_LOG(call, INFO)
        << Activity::current()->DebugTag()
        << "[call] RecvMessage: outstanding_recv "
           "finishes: received "
        << (*recv_message_)->data.raw.slice_buffer.length << " byte message";
    recv_message_ = nullptr;
    return Success{};
  }

  grpc_byte_buffer** recv_message_ = nullptr;
  uint32_t test_only_last_message_flags_ = 0;
  // Compression algorithm for incoming data
  grpc_compression_algorithm incoming_compression_algorithm_ =
      GRPC_COMPRESS_NONE;
};
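
// Hedged sketch of wiring MessageReceiver into a batch (`batch_index`,
// `message_receiver`, and `puller` are placeholders for members of the
// surface call object):
//
//   auto recv_message = batch_index.OpHandler<GRPC_OP_RECV_MESSAGE>(
//       [&](const grpc_op& op) {
//         // Captures op.data.recv_message.recv_message now, and returns a
//         // promise factory that pulls the next message and publishes it
//         // there when run on the call party.
//         return message_receiver.MakeBatchOp(op, &puller);
//       });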

std::string MakeErrorString(const ServerMetadata* trailing_metadata);

}  // namespace grpc_core

#endif  // GRPC_SRC_CORE_LIB_SURFACE_CALL_UTILS_H