• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #ifndef GRPCPP_IMPL_CALL_OP_SET_H
20 #define GRPCPP_IMPL_CALL_OP_SET_H
21 
22 #include <grpc/grpc.h>
23 #include <grpc/impl/compression_types.h>
24 #include <grpc/impl/grpc_types.h>
25 #include <grpc/slice.h>
26 #include <grpc/support/alloc.h>
27 #include <grpcpp/client_context.h>
28 #include <grpcpp/completion_queue.h>
29 #include <grpcpp/impl/call.h>
30 #include <grpcpp/impl/call_hook.h>
31 #include <grpcpp/impl/call_op_set_interface.h>
32 #include <grpcpp/impl/codegen/intercepted_channel.h>
33 #include <grpcpp/impl/completion_queue_tag.h>
34 #include <grpcpp/impl/interceptor_common.h>
35 #include <grpcpp/impl/serialization_traits.h>
36 #include <grpcpp/support/byte_buffer.h>
37 #include <grpcpp/support/config.h>
38 #include <grpcpp/support/slice.h>
39 #include <grpcpp/support/string_ref.h>
40 
41 #include <cstring>
42 #include <map>
43 #include <memory>
44 
45 #include "absl/log/absl_check.h"
46 #include "absl/log/absl_log.h"
47 
48 namespace grpc {
49 
50 namespace internal {
51 class Call;
52 class CallHook;
53 
54 // TODO(yangg) if the map is changed before we send, the pointers will be a
55 // mess. Make sure it does not happen.
FillMetadataArray(const std::multimap<std::string,std::string> & metadata,size_t * metadata_count,const std::string & optional_error_details)56 inline grpc_metadata* FillMetadataArray(
57     const std::multimap<std::string, std::string>& metadata,
58     size_t* metadata_count, const std::string& optional_error_details) {
59   *metadata_count = metadata.size() + (optional_error_details.empty() ? 0 : 1);
60   if (*metadata_count == 0) {
61     return nullptr;
62   }
63   grpc_metadata* metadata_array = static_cast<grpc_metadata*>(
64       gpr_malloc((*metadata_count) * sizeof(grpc_metadata)));
65   size_t i = 0;
66   for (auto iter = metadata.cbegin(); iter != metadata.cend(); ++iter, ++i) {
67     metadata_array[i].key = SliceReferencingString(iter->first);
68     metadata_array[i].value = SliceReferencingString(iter->second);
69   }
70   if (!optional_error_details.empty()) {
71     metadata_array[i].key = grpc_slice_from_static_buffer(
72         kBinaryErrorDetailsKey, sizeof(kBinaryErrorDetailsKey) - 1);
73     metadata_array[i].value = SliceReferencingString(optional_error_details);
74   }
75   return metadata_array;
76 }
77 }  // namespace internal
78 
79 /// Per-message write options.
80 class WriteOptions {
81  public:
WriteOptions()82   WriteOptions() : flags_(0), last_message_(false) {}
83 
84   /// Clear all flags.
Clear()85   inline void Clear() { flags_ = 0; }
86 
87   /// Returns raw flags bitset.
flags()88   inline uint32_t flags() const { return flags_; }
89 
90   /// Sets flag for the disabling of compression for the next message write.
91   ///
92   /// \sa GRPC_WRITE_NO_COMPRESS
set_no_compression()93   inline WriteOptions& set_no_compression() {
94     SetBit(GRPC_WRITE_NO_COMPRESS);
95     return *this;
96   }
97 
98   /// Clears flag for the disabling of compression for the next message write.
99   ///
100   /// \sa GRPC_WRITE_NO_COMPRESS
clear_no_compression()101   inline WriteOptions& clear_no_compression() {
102     ClearBit(GRPC_WRITE_NO_COMPRESS);
103     return *this;
104   }
105 
106   /// Get value for the flag indicating whether compression for the next
107   /// message write is forcefully disabled.
108   ///
109   /// \sa GRPC_WRITE_NO_COMPRESS
get_no_compression()110   inline bool get_no_compression() const {
111     return GetBit(GRPC_WRITE_NO_COMPRESS);
112   }
113 
114   /// Sets flag indicating that the write may be buffered and need not go out on
115   /// the wire immediately.
116   ///
117   /// \sa GRPC_WRITE_BUFFER_HINT
set_buffer_hint()118   inline WriteOptions& set_buffer_hint() {
119     SetBit(GRPC_WRITE_BUFFER_HINT);
120     return *this;
121   }
122 
123   /// Clears flag indicating that the write may be buffered and need not go out
124   /// on the wire immediately.
125   ///
126   /// \sa GRPC_WRITE_BUFFER_HINT
clear_buffer_hint()127   inline WriteOptions& clear_buffer_hint() {
128     ClearBit(GRPC_WRITE_BUFFER_HINT);
129     return *this;
130   }
131 
132   /// Get value for the flag indicating that the write may be buffered and need
133   /// not go out on the wire immediately.
134   ///
135   /// \sa GRPC_WRITE_BUFFER_HINT
get_buffer_hint()136   inline bool get_buffer_hint() const { return GetBit(GRPC_WRITE_BUFFER_HINT); }
137 
138   /// corked bit: aliases set_buffer_hint currently, with the intent that
139   /// set_buffer_hint will be removed in the future
set_corked()140   inline WriteOptions& set_corked() {
141     SetBit(GRPC_WRITE_BUFFER_HINT);
142     return *this;
143   }
144 
clear_corked()145   inline WriteOptions& clear_corked() {
146     ClearBit(GRPC_WRITE_BUFFER_HINT);
147     return *this;
148   }
149 
is_corked()150   inline bool is_corked() const { return GetBit(GRPC_WRITE_BUFFER_HINT); }
151 
152   /// last-message bit: indicates this is the last message in a stream
153   /// client-side:  makes Write the equivalent of performing Write, WritesDone
154   /// in a single step
155   /// server-side:  hold the Write until the service handler returns (sync api)
156   /// or until Finish is called (async api)
set_last_message()157   inline WriteOptions& set_last_message() {
158     last_message_ = true;
159     return *this;
160   }
161 
162   /// Clears flag indicating that this is the last message in a stream,
163   /// disabling coalescing.
clear_last_message()164   inline WriteOptions& clear_last_message() {
165     last_message_ = false;
166     return *this;
167   }
168 
169   /// Get value for the flag indicating that this is the last message, and
170   /// should be coalesced with trailing metadata.
171   ///
172   /// \sa GRPC_WRITE_LAST_MESSAGE
is_last_message()173   bool is_last_message() const { return last_message_; }
174 
175   /// Guarantee that all bytes have been written to the socket before completing
176   /// this write (usually writes are completed when they pass flow control).
set_write_through()177   inline WriteOptions& set_write_through() {
178     SetBit(GRPC_WRITE_THROUGH);
179     return *this;
180   }
181 
clear_write_through()182   inline WriteOptions& clear_write_through() {
183     ClearBit(GRPC_WRITE_THROUGH);
184     return *this;
185   }
186 
is_write_through()187   inline bool is_write_through() const { return GetBit(GRPC_WRITE_THROUGH); }
188 
189  private:
SetBit(const uint32_t mask)190   void SetBit(const uint32_t mask) { flags_ |= mask; }
191 
ClearBit(const uint32_t mask)192   void ClearBit(const uint32_t mask) { flags_ &= ~mask; }
193 
GetBit(const uint32_t mask)194   bool GetBit(const uint32_t mask) const { return (flags_ & mask) != 0; }
195 
196   uint32_t flags_;
197   bool last_message_;
198 };
199 
200 namespace internal {
201 
202 /// Default argument for CallOpSet. The Unused parameter is unused by
203 /// the class, but can be used for generating multiple names for the
204 /// same thing.
205 template <int Unused>
206 class CallNoOp {
207  protected:
AddOp(grpc_op *,size_t *)208   void AddOp(grpc_op* /*ops*/, size_t* /*nops*/) {}
FinishOp(bool *)209   void FinishOp(bool* /*status*/) {}
SetInterceptionHookPoint(InterceptorBatchMethodsImpl *)210   void SetInterceptionHookPoint(
211       InterceptorBatchMethodsImpl* /*interceptor_methods*/) {}
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl *)212   void SetFinishInterceptionHookPoint(
213       InterceptorBatchMethodsImpl* /*interceptor_methods*/) {}
SetHijackingState(InterceptorBatchMethodsImpl *)214   void SetHijackingState(InterceptorBatchMethodsImpl* /*interceptor_methods*/) {
215   }
216 };
217 
218 class CallOpSendInitialMetadata {
219  public:
CallOpSendInitialMetadata()220   CallOpSendInitialMetadata() : send_(false) {
221     maybe_compression_level_.is_set = false;
222   }
223 
SendInitialMetadata(std::multimap<std::string,std::string> * metadata,uint32_t flags)224   void SendInitialMetadata(std::multimap<std::string, std::string>* metadata,
225                            uint32_t flags) {
226     maybe_compression_level_.is_set = false;
227     send_ = true;
228     flags_ = flags;
229     metadata_map_ = metadata;
230   }
231 
set_compression_level(grpc_compression_level level)232   void set_compression_level(grpc_compression_level level) {
233     maybe_compression_level_.is_set = true;
234     maybe_compression_level_.level = level;
235   }
236 
237  protected:
AddOp(grpc_op * ops,size_t * nops)238   void AddOp(grpc_op* ops, size_t* nops) {
239     if (!send_ || hijacked_) return;
240     grpc_op* op = &ops[(*nops)++];
241     op->op = GRPC_OP_SEND_INITIAL_METADATA;
242     op->flags = flags_;
243     op->reserved = nullptr;
244     initial_metadata_ =
245         FillMetadataArray(*metadata_map_, &initial_metadata_count_, "");
246     op->data.send_initial_metadata.count = initial_metadata_count_;
247     op->data.send_initial_metadata.metadata = initial_metadata_;
248     op->data.send_initial_metadata.maybe_compression_level.is_set =
249         maybe_compression_level_.is_set;
250     if (maybe_compression_level_.is_set) {
251       op->data.send_initial_metadata.maybe_compression_level.level =
252           maybe_compression_level_.level;
253     }
254   }
FinishOp(bool *)255   void FinishOp(bool* /*status*/) {
256     if (!send_ || hijacked_) return;
257     gpr_free(initial_metadata_);
258     send_ = false;
259   }
260 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)261   void SetInterceptionHookPoint(
262       InterceptorBatchMethodsImpl* interceptor_methods) {
263     if (!send_) return;
264     interceptor_methods->AddInterceptionHookPoint(
265         experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA);
266     interceptor_methods->SetSendInitialMetadata(metadata_map_);
267   }
268 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl *)269   void SetFinishInterceptionHookPoint(
270       InterceptorBatchMethodsImpl* /*interceptor_methods*/) {}
271 
SetHijackingState(InterceptorBatchMethodsImpl *)272   void SetHijackingState(InterceptorBatchMethodsImpl* /*interceptor_methods*/) {
273     hijacked_ = true;
274   }
275 
276   bool hijacked_ = false;
277   bool send_;
278   uint32_t flags_;
279   size_t initial_metadata_count_;
280   std::multimap<std::string, std::string>* metadata_map_;
281   grpc_metadata* initial_metadata_;
282   struct {
283     bool is_set;
284     grpc_compression_level level;
285   } maybe_compression_level_;
286 };
287 
288 class CallOpSendMessage {
289  public:
CallOpSendMessage()290   CallOpSendMessage() : send_buf_() {}
291 
292   /// Send \a message using \a options for the write. The \a options are cleared
293   /// after use.
294   template <class M>
295   GRPC_MUST_USE_RESULT Status SendMessage(const M& message,
296                                           WriteOptions options);
297 
298   template <class M>
299   GRPC_MUST_USE_RESULT Status SendMessage(const M& message);
300 
301   /// Send \a message using \a options for the write. The \a options are cleared
302   /// after use. This form of SendMessage allows gRPC to reference \a message
303   /// beyond the lifetime of SendMessage.
304   template <class M>
305   GRPC_MUST_USE_RESULT Status SendMessagePtr(const M* message,
306                                              WriteOptions options);
307 
308   /// This form of SendMessage allows gRPC to reference \a message beyond the
309   /// lifetime of SendMessage.
310   template <class M>
311   GRPC_MUST_USE_RESULT Status SendMessagePtr(const M* message);
312 
313  protected:
AddOp(grpc_op * ops,size_t * nops)314   void AddOp(grpc_op* ops, size_t* nops) {
315     if (msg_ == nullptr && !send_buf_.Valid()) return;
316     if (hijacked_) {
317       serializer_ = nullptr;
318       return;
319     }
320     if (msg_ != nullptr) {
321       ABSL_CHECK(serializer_(msg_).ok());
322     }
323     serializer_ = nullptr;
324     grpc_op* op = &ops[(*nops)++];
325     op->op = GRPC_OP_SEND_MESSAGE;
326     op->flags = write_options_.flags();
327     op->reserved = nullptr;
328     op->data.send_message.send_message = send_buf_.c_buffer();
329     // Flags are per-message: clear them after use.
330     write_options_.Clear();
331   }
FinishOp(bool * status)332   void FinishOp(bool* status) {
333     if (msg_ == nullptr && !send_buf_.Valid()) return;
334     send_buf_.Clear();
335     if (hijacked_ && failed_send_) {
336       // Hijacking interceptor failed this Op
337       *status = false;
338     } else if (!*status) {
339       // This Op was passed down to core and the Op failed
340       failed_send_ = true;
341     }
342   }
343 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)344   void SetInterceptionHookPoint(
345       InterceptorBatchMethodsImpl* interceptor_methods) {
346     if (msg_ == nullptr && !send_buf_.Valid()) return;
347     interceptor_methods->AddInterceptionHookPoint(
348         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
349     interceptor_methods->SetSendMessage(&send_buf_, &msg_, &failed_send_,
350                                         serializer_);
351   }
352 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)353   void SetFinishInterceptionHookPoint(
354       InterceptorBatchMethodsImpl* interceptor_methods) {
355     if (msg_ != nullptr || send_buf_.Valid()) {
356       interceptor_methods->AddInterceptionHookPoint(
357           experimental::InterceptionHookPoints::POST_SEND_MESSAGE);
358     }
359     send_buf_.Clear();
360     msg_ = nullptr;
361     // The contents of the SendMessage value that was previously set
362     // has had its references stolen by core's operations
363     interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_,
364                                         nullptr);
365   }
366 
SetHijackingState(InterceptorBatchMethodsImpl *)367   void SetHijackingState(InterceptorBatchMethodsImpl* /*interceptor_methods*/) {
368     hijacked_ = true;
369   }
370 
371  private:
372   const void* msg_ = nullptr;  // The original non-serialized message
373   bool hijacked_ = false;
374   bool failed_send_ = false;
375   ByteBuffer send_buf_;
376   WriteOptions write_options_;
377   std::function<Status(const void*)> serializer_;
378 };
379 
380 template <class M>
SendMessage(const M & message,WriteOptions options)381 Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) {
382   write_options_ = options;
383   // Serialize immediately since we do not have access to the message pointer
384   bool own_buf;
385   Status result = SerializationTraits<M>::Serialize(
386       message, send_buf_.bbuf_ptr(), &own_buf);
387   if (!own_buf) {
388     send_buf_.Duplicate();
389   }
390   return result;
391 }
392 
393 template <class M>
SendMessage(const M & message)394 Status CallOpSendMessage::SendMessage(const M& message) {
395   return SendMessage(message, WriteOptions());
396 }
397 
398 template <class M>
SendMessagePtr(const M * message,WriteOptions options)399 Status CallOpSendMessage::SendMessagePtr(const M* message,
400                                          WriteOptions options) {
401   msg_ = message;
402   write_options_ = options;
403   // Store the serializer for later since we have access to the message
404   serializer_ = [this](const void* message) {
405     bool own_buf;
406     // TODO(vjpai): Remove the void below when possible
407     // The void in the template parameter below should not be needed
408     // (since it should be implicit) but is needed due to an observed
409     // difference in behavior between clang and gcc for certain internal users
410     Status result = SerializationTraits<M>::Serialize(
411         *static_cast<const M*>(message), send_buf_.bbuf_ptr(), &own_buf);
412     if (!own_buf) {
413       send_buf_.Duplicate();
414     }
415     return result;
416   };
417   return Status();
418 }
419 
420 template <class M>
SendMessagePtr(const M * message)421 Status CallOpSendMessage::SendMessagePtr(const M* message) {
422   return SendMessagePtr(message, WriteOptions());
423 }
424 
425 template <class R>
426 class CallOpRecvMessage {
427  public:
RecvMessage(R * message)428   void RecvMessage(R* message) { message_ = message; }
429 
430   // Do not change status if no message is received.
AllowNoMessage()431   void AllowNoMessage() { allow_not_getting_message_ = true; }
432 
433   bool got_message = false;
434 
435  protected:
AddOp(grpc_op * ops,size_t * nops)436   void AddOp(grpc_op* ops, size_t* nops) {
437     if (message_ == nullptr || hijacked_) return;
438     grpc_op* op = &ops[(*nops)++];
439     op->op = GRPC_OP_RECV_MESSAGE;
440     op->flags = 0;
441     op->reserved = nullptr;
442     op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr();
443   }
444 
FinishOp(bool * status)445   void FinishOp(bool* status) {
446     if (message_ == nullptr) return;
447     if (recv_buf_.Valid()) {
448       if (*status) {
449         got_message = *status =
450             SerializationTraits<R>::Deserialize(recv_buf_.bbuf_ptr(), message_)
451                 .ok();
452         recv_buf_.Release();
453       } else {
454         got_message = false;
455         recv_buf_.Clear();
456       }
457     } else if (hijacked_) {
458       if (hijacked_recv_message_failed_) {
459         FinishOpRecvMessageFailureHandler(status);
460       } else {
461         // The op was hijacked and it was successful. There is no further action
462         // to be performed since the message is already in its non-serialized
463         // form.
464       }
465     } else {
466       FinishOpRecvMessageFailureHandler(status);
467     }
468   }
469 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)470   void SetInterceptionHookPoint(
471       InterceptorBatchMethodsImpl* interceptor_methods) {
472     if (message_ == nullptr) return;
473     interceptor_methods->SetRecvMessage(message_,
474                                         &hijacked_recv_message_failed_);
475   }
476 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)477   void SetFinishInterceptionHookPoint(
478       InterceptorBatchMethodsImpl* interceptor_methods) {
479     if (message_ == nullptr) return;
480     interceptor_methods->AddInterceptionHookPoint(
481         experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
482     if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr);
483   }
SetHijackingState(InterceptorBatchMethodsImpl * interceptor_methods)484   void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
485     hijacked_ = true;
486     if (message_ == nullptr) return;
487     interceptor_methods->AddInterceptionHookPoint(
488         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE);
489     got_message = true;
490   }
491 
492  private:
493   // Sets got_message and \a status for a failed recv message op
FinishOpRecvMessageFailureHandler(bool * status)494   void FinishOpRecvMessageFailureHandler(bool* status) {
495     got_message = false;
496     if (!allow_not_getting_message_) {
497       *status = false;
498     }
499   }
500 
501   R* message_ = nullptr;
502   ByteBuffer recv_buf_;
503   bool allow_not_getting_message_ = false;
504   bool hijacked_ = false;
505   bool hijacked_recv_message_failed_ = false;
506 };
507 
508 class DeserializeFunc {
509  public:
510   virtual Status Deserialize(ByteBuffer* buf) = 0;
~DeserializeFunc()511   virtual ~DeserializeFunc() {}
512 };
513 
514 template <class R>
515 class DeserializeFuncType final : public DeserializeFunc {
516  public:
DeserializeFuncType(R * message)517   explicit DeserializeFuncType(R* message) : message_(message) {}
Deserialize(ByteBuffer * buf)518   Status Deserialize(ByteBuffer* buf) override {
519     return SerializationTraits<R>::Deserialize(buf->bbuf_ptr(), message_);
520   }
521 
~DeserializeFuncType()522   ~DeserializeFuncType() override {}
523 
524  private:
525   R* message_;  // Not a managed pointer because management is external to this
526 };
527 
528 class CallOpGenericRecvMessage {
529  public:
530   template <class R>
RecvMessage(R * message)531   void RecvMessage(R* message) {
532     // Use an explicit base class pointer to avoid resolution error in the
533     // following unique_ptr::reset for some old implementations.
534     DeserializeFunc* func = new DeserializeFuncType<R>(message);
535     deserialize_.reset(func);
536     message_ = message;
537   }
538 
539   // Do not change status if no message is received.
AllowNoMessage()540   void AllowNoMessage() { allow_not_getting_message_ = true; }
541 
542   bool got_message = false;
543 
544  protected:
AddOp(grpc_op * ops,size_t * nops)545   void AddOp(grpc_op* ops, size_t* nops) {
546     if (!deserialize_ || hijacked_) return;
547     grpc_op* op = &ops[(*nops)++];
548     op->op = GRPC_OP_RECV_MESSAGE;
549     op->flags = 0;
550     op->reserved = nullptr;
551     op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr();
552   }
553 
FinishOp(bool * status)554   void FinishOp(bool* status) {
555     if (!deserialize_) return;
556     if (recv_buf_.Valid()) {
557       if (*status) {
558         got_message = true;
559         *status = deserialize_->Deserialize(&recv_buf_).ok();
560         recv_buf_.Release();
561       } else {
562         got_message = false;
563         recv_buf_.Clear();
564       }
565     } else if (hijacked_) {
566       if (hijacked_recv_message_failed_) {
567         FinishOpRecvMessageFailureHandler(status);
568       } else {
569         // The op was hijacked and it was successful. There is no further action
570         // to be performed since the message is already in its non-serialized
571         // form.
572       }
573     } else {
574       got_message = false;
575       if (!allow_not_getting_message_) {
576         *status = false;
577       }
578     }
579   }
580 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)581   void SetInterceptionHookPoint(
582       InterceptorBatchMethodsImpl* interceptor_methods) {
583     if (!deserialize_) return;
584     interceptor_methods->SetRecvMessage(message_,
585                                         &hijacked_recv_message_failed_);
586   }
587 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)588   void SetFinishInterceptionHookPoint(
589       InterceptorBatchMethodsImpl* interceptor_methods) {
590     if (!deserialize_) return;
591     interceptor_methods->AddInterceptionHookPoint(
592         experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
593     if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr);
594     deserialize_.reset();
595   }
SetHijackingState(InterceptorBatchMethodsImpl * interceptor_methods)596   void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
597     hijacked_ = true;
598     if (!deserialize_) return;
599     interceptor_methods->AddInterceptionHookPoint(
600         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE);
601     got_message = true;
602   }
603 
604  private:
605   // Sets got_message and \a status for a failed recv message op
FinishOpRecvMessageFailureHandler(bool * status)606   void FinishOpRecvMessageFailureHandler(bool* status) {
607     got_message = false;
608     if (!allow_not_getting_message_) {
609       *status = false;
610     }
611   }
612 
613   void* message_ = nullptr;
614   std::unique_ptr<DeserializeFunc> deserialize_;
615   ByteBuffer recv_buf_;
616   bool allow_not_getting_message_ = false;
617   bool hijacked_ = false;
618   bool hijacked_recv_message_failed_ = false;
619 };
620 
621 class CallOpClientSendClose {
622  public:
CallOpClientSendClose()623   CallOpClientSendClose() : send_(false) {}
624 
ClientSendClose()625   void ClientSendClose() { send_ = true; }
626 
627  protected:
AddOp(grpc_op * ops,size_t * nops)628   void AddOp(grpc_op* ops, size_t* nops) {
629     if (!send_ || hijacked_) return;
630     grpc_op* op = &ops[(*nops)++];
631     op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT;
632     op->flags = 0;
633     op->reserved = nullptr;
634   }
FinishOp(bool *)635   void FinishOp(bool* /*status*/) { send_ = false; }
636 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)637   void SetInterceptionHookPoint(
638       InterceptorBatchMethodsImpl* interceptor_methods) {
639     if (!send_) return;
640     interceptor_methods->AddInterceptionHookPoint(
641         experimental::InterceptionHookPoints::PRE_SEND_CLOSE);
642   }
643 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl *)644   void SetFinishInterceptionHookPoint(
645       InterceptorBatchMethodsImpl* /*interceptor_methods*/) {}
646 
SetHijackingState(InterceptorBatchMethodsImpl *)647   void SetHijackingState(InterceptorBatchMethodsImpl* /*interceptor_methods*/) {
648     hijacked_ = true;
649   }
650 
651  private:
652   bool hijacked_ = false;
653   bool send_;
654 };
655 
656 class CallOpServerSendStatus {
657  public:
CallOpServerSendStatus()658   CallOpServerSendStatus() : send_status_available_(false) {}
659 
ServerSendStatus(std::multimap<std::string,std::string> * trailing_metadata,const Status & status)660   void ServerSendStatus(
661       std::multimap<std::string, std::string>* trailing_metadata,
662       const Status& status) {
663     send_error_details_ = status.error_details();
664     metadata_map_ = trailing_metadata;
665     send_status_available_ = true;
666     send_status_code_ = static_cast<grpc_status_code>(status.error_code());
667     send_error_message_ = status.error_message();
668   }
669 
670  protected:
AddOp(grpc_op * ops,size_t * nops)671   void AddOp(grpc_op* ops, size_t* nops) {
672     if (!send_status_available_ || hijacked_) return;
673     trailing_metadata_ = FillMetadataArray(
674         *metadata_map_, &trailing_metadata_count_, send_error_details_);
675     grpc_op* op = &ops[(*nops)++];
676     op->op = GRPC_OP_SEND_STATUS_FROM_SERVER;
677     op->data.send_status_from_server.trailing_metadata_count =
678         trailing_metadata_count_;
679     op->data.send_status_from_server.trailing_metadata = trailing_metadata_;
680     op->data.send_status_from_server.status = send_status_code_;
681     error_message_slice_ = SliceReferencingString(send_error_message_);
682     op->data.send_status_from_server.status_details =
683         send_error_message_.empty() ? nullptr : &error_message_slice_;
684     op->flags = 0;
685     op->reserved = nullptr;
686   }
687 
FinishOp(bool *)688   void FinishOp(bool* /*status*/) {
689     if (!send_status_available_ || hijacked_) return;
690     gpr_free(trailing_metadata_);
691     send_status_available_ = false;
692   }
693 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)694   void SetInterceptionHookPoint(
695       InterceptorBatchMethodsImpl* interceptor_methods) {
696     if (!send_status_available_) return;
697     interceptor_methods->AddInterceptionHookPoint(
698         experimental::InterceptionHookPoints::PRE_SEND_STATUS);
699     interceptor_methods->SetSendTrailingMetadata(metadata_map_);
700     interceptor_methods->SetSendStatus(&send_status_code_, &send_error_details_,
701                                        &send_error_message_);
702   }
703 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl *)704   void SetFinishInterceptionHookPoint(
705       InterceptorBatchMethodsImpl* /*interceptor_methods*/) {}
706 
SetHijackingState(InterceptorBatchMethodsImpl *)707   void SetHijackingState(InterceptorBatchMethodsImpl* /*interceptor_methods*/) {
708     hijacked_ = true;
709   }
710 
711  private:
712   bool hijacked_ = false;
713   bool send_status_available_;
714   grpc_status_code send_status_code_;
715   std::string send_error_details_;
716   std::string send_error_message_;
717   size_t trailing_metadata_count_;
718   std::multimap<std::string, std::string>* metadata_map_;
719   grpc_metadata* trailing_metadata_;
720   grpc_slice error_message_slice_;
721 };
722 
723 class CallOpRecvInitialMetadata {
724  public:
CallOpRecvInitialMetadata()725   CallOpRecvInitialMetadata() : metadata_map_(nullptr) {}
726 
RecvInitialMetadata(grpc::ClientContext * context)727   void RecvInitialMetadata(grpc::ClientContext* context) {
728     context->initial_metadata_received_ = true;
729     metadata_map_ = &context->recv_initial_metadata_;
730   }
731 
732  protected:
AddOp(grpc_op * ops,size_t * nops)733   void AddOp(grpc_op* ops, size_t* nops) {
734     if (metadata_map_ == nullptr || hijacked_) return;
735     grpc_op* op = &ops[(*nops)++];
736     op->op = GRPC_OP_RECV_INITIAL_METADATA;
737     op->data.recv_initial_metadata.recv_initial_metadata = metadata_map_->arr();
738     op->flags = 0;
739     op->reserved = nullptr;
740   }
741 
FinishOp(bool *)742   void FinishOp(bool* /*status*/) {
743     if (metadata_map_ == nullptr || hijacked_) return;
744   }
745 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)746   void SetInterceptionHookPoint(
747       InterceptorBatchMethodsImpl* interceptor_methods) {
748     interceptor_methods->SetRecvInitialMetadata(metadata_map_);
749   }
750 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)751   void SetFinishInterceptionHookPoint(
752       InterceptorBatchMethodsImpl* interceptor_methods) {
753     if (metadata_map_ == nullptr) return;
754     interceptor_methods->AddInterceptionHookPoint(
755         experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
756     metadata_map_ = nullptr;
757   }
758 
SetHijackingState(InterceptorBatchMethodsImpl * interceptor_methods)759   void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
760     hijacked_ = true;
761     if (metadata_map_ == nullptr) return;
762     interceptor_methods->AddInterceptionHookPoint(
763         experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA);
764   }
765 
766  private:
767   bool hijacked_ = false;
768   MetadataMap* metadata_map_;
769 };
770 
771 class CallOpClientRecvStatus {
772  public:
CallOpClientRecvStatus()773   CallOpClientRecvStatus()
774       : metadata_map_(nullptr),
775         recv_status_(nullptr),
776         debug_error_string_(nullptr) {}
777 
ClientRecvStatus(grpc::ClientContext * context,Status * status)778   void ClientRecvStatus(grpc::ClientContext* context, Status* status) {
779     client_context_ = context;
780     metadata_map_ = &client_context_->trailing_metadata_;
781     recv_status_ = status;
782     error_message_ = grpc_empty_slice();
783   }
784 
785  protected:
AddOp(grpc_op * ops,size_t * nops)786   void AddOp(grpc_op* ops, size_t* nops) {
787     if (recv_status_ == nullptr || hijacked_) return;
788     grpc_op* op = &ops[(*nops)++];
789     op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
790     op->data.recv_status_on_client.trailing_metadata = metadata_map_->arr();
791     op->data.recv_status_on_client.status = &status_code_;
792     op->data.recv_status_on_client.status_details = &error_message_;
793     op->data.recv_status_on_client.error_string = &debug_error_string_;
794     op->flags = 0;
795     op->reserved = nullptr;
796   }
797 
FinishOp(bool *)798   void FinishOp(bool* /*status*/) {
799     if (recv_status_ == nullptr || hijacked_) return;
800     if (static_cast<StatusCode>(status_code_) == StatusCode::OK) {
801       *recv_status_ = Status();
802       ABSL_DCHECK_EQ(debug_error_string_, nullptr);
803     } else {
804       *recv_status_ =
805           Status(static_cast<StatusCode>(status_code_),
806                  GRPC_SLICE_IS_EMPTY(error_message_)
807                      ? std::string()
808                      : std::string(GRPC_SLICE_START_PTR(error_message_),
809                                    GRPC_SLICE_END_PTR(error_message_)),
810                  metadata_map_->GetBinaryErrorDetails());
811       if (debug_error_string_ != nullptr) {
812         client_context_->set_debug_error_string(debug_error_string_);
813         gpr_free(const_cast<char*>(debug_error_string_));
814       }
815     }
816     // TODO(soheil): Find callers that set debug string even for status OK,
817     //               and fix them.
818     grpc_slice_unref(error_message_);
819   }
820 
SetInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)821   void SetInterceptionHookPoint(
822       InterceptorBatchMethodsImpl* interceptor_methods) {
823     interceptor_methods->SetRecvStatus(recv_status_);
824     interceptor_methods->SetRecvTrailingMetadata(metadata_map_);
825   }
826 
SetFinishInterceptionHookPoint(InterceptorBatchMethodsImpl * interceptor_methods)827   void SetFinishInterceptionHookPoint(
828       InterceptorBatchMethodsImpl* interceptor_methods) {
829     if (recv_status_ == nullptr) return;
830     interceptor_methods->AddInterceptionHookPoint(
831         experimental::InterceptionHookPoints::POST_RECV_STATUS);
832     recv_status_ = nullptr;
833   }
834 
SetHijackingState(InterceptorBatchMethodsImpl * interceptor_methods)835   void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
836     hijacked_ = true;
837     if (recv_status_ == nullptr) return;
838     interceptor_methods->AddInterceptionHookPoint(
839         experimental::InterceptionHookPoints::PRE_RECV_STATUS);
840   }
841 
842  private:
843   bool hijacked_ = false;
844   grpc::ClientContext* client_context_;
845   MetadataMap* metadata_map_;
846   Status* recv_status_;
847   const char* debug_error_string_;
848   grpc_status_code status_code_;
849   grpc_slice error_message_;
850 };
851 
852 template <class Op1 = CallNoOp<1>, class Op2 = CallNoOp<2>,
853           class Op3 = CallNoOp<3>, class Op4 = CallNoOp<4>,
854           class Op5 = CallNoOp<5>, class Op6 = CallNoOp<6>>
855 class CallOpSet;
856 
857 /// Primary implementation of CallOpSetInterface.
858 /// Since we cannot use variadic templates, we declare slots up to
859 /// the maximum count of ops we'll need in a set. We leverage the
860 /// empty base class optimization to slim this class (especially
861 /// when there are many unused slots used). To avoid duplicate base classes,
862 /// the template parameter for CallNoOp is varied by argument position.
863 template <class Op1, class Op2, class Op3, class Op4, class Op5, class Op6>
864 class CallOpSet : public CallOpSetInterface,
865                   public Op1,
866                   public Op2,
867                   public Op3,
868                   public Op4,
869                   public Op5,
870                   public Op6 {
871  public:
CallOpSet()872   CallOpSet() : core_cq_tag_(this), return_tag_(this) {}
873   // The copy constructor and assignment operator reset the value of
874   // core_cq_tag_, return_tag_, done_intercepting_ and interceptor_methods_
875   // since those are only meaningful on a specific object, not across objects.
CallOpSet(const CallOpSet & other)876   CallOpSet(const CallOpSet& other)
877       : core_cq_tag_(this),
878         return_tag_(this),
879         call_(other.call_),
880         done_intercepting_(false),
881         interceptor_methods_(InterceptorBatchMethodsImpl()) {}
882 
883   CallOpSet& operator=(const CallOpSet& other) {
884     if (&other == this) {
885       return *this;
886     }
887     core_cq_tag_ = this;
888     return_tag_ = this;
889     call_ = other.call_;
890     done_intercepting_ = false;
891     interceptor_methods_ = InterceptorBatchMethodsImpl();
892     return *this;
893   }
894 
FillOps(Call * call)895   void FillOps(Call* call) override {
896     done_intercepting_ = false;
897     grpc_call_ref(call->call());
898     call_ =
899         *call;  // It's fine to create a copy of call since it's just pointers
900 
901     if (RunInterceptors()) {
902       ContinueFillOpsAfterInterception();
903     } else {
904       // After the interceptors are run, ContinueFillOpsAfterInterception will
905       // be run
906     }
907   }
908 
FinalizeResult(void ** tag,bool * status)909   bool FinalizeResult(void** tag, bool* status) override {
910     if (done_intercepting_) {
911       // Complete the avalanching since we are done with this batch of ops
912       call_.cq()->CompleteAvalanching();
913       // We have already finished intercepting and filling in the results. This
914       // round trip from the core needed to be made because interceptors were
915       // run
916       *tag = return_tag_;
917       *status = saved_status_;
918       grpc_call_unref(call_.call());
919       return true;
920     }
921 
922     this->Op1::FinishOp(status);
923     this->Op2::FinishOp(status);
924     this->Op3::FinishOp(status);
925     this->Op4::FinishOp(status);
926     this->Op5::FinishOp(status);
927     this->Op6::FinishOp(status);
928     saved_status_ = *status;
929     if (RunInterceptorsPostRecv()) {
930       *tag = return_tag_;
931       grpc_call_unref(call_.call());
932       return true;
933     }
934     // Interceptors are going to be run, so we can't return the tag just yet.
935     // After the interceptors are run, ContinueFinalizeResultAfterInterception
936     return false;
937   }
938 
set_output_tag(void * return_tag)939   void set_output_tag(void* return_tag) { return_tag_ = return_tag; }
940 
core_cq_tag()941   void* core_cq_tag() override { return core_cq_tag_; }
942 
943   /// set_core_cq_tag is used to provide a different core CQ tag than "this".
944   /// This is used for callback-based tags, where the core tag is the core
945   /// callback function. It does not change the use or behavior of any other
946   /// function (such as FinalizeResult)
set_core_cq_tag(void * core_cq_tag)947   void set_core_cq_tag(void* core_cq_tag) { core_cq_tag_ = core_cq_tag; }
948 
949   // This will be called while interceptors are run if the RPC is a hijacked
950   // RPC. This should set hijacking state for each of the ops.
SetHijackingState()951   void SetHijackingState() override {
952     this->Op1::SetHijackingState(&interceptor_methods_);
953     this->Op2::SetHijackingState(&interceptor_methods_);
954     this->Op3::SetHijackingState(&interceptor_methods_);
955     this->Op4::SetHijackingState(&interceptor_methods_);
956     this->Op5::SetHijackingState(&interceptor_methods_);
957     this->Op6::SetHijackingState(&interceptor_methods_);
958   }
959 
960   // Should be called after interceptors are done running
ContinueFillOpsAfterInterception()961   void ContinueFillOpsAfterInterception() override {
962     static const size_t MAX_OPS = 6;
963     grpc_op ops[MAX_OPS];
964     size_t nops = 0;
965     this->Op1::AddOp(ops, &nops);
966     this->Op2::AddOp(ops, &nops);
967     this->Op3::AddOp(ops, &nops);
968     this->Op4::AddOp(ops, &nops);
969     this->Op5::AddOp(ops, &nops);
970     this->Op6::AddOp(ops, &nops);
971 
972     grpc_call_error err =
973         grpc_call_start_batch(call_.call(), ops, nops, core_cq_tag(), nullptr);
974 
975     if (err != GRPC_CALL_OK) {
976       // A failure here indicates an API misuse; for example, doing a Write
977       // while another Write is already pending on the same RPC or invoking
978       // WritesDone multiple times
979       ABSL_LOG(ERROR) << "API misuse of type " << grpc_call_error_to_string(err)
980                       << " observed";
981       ABSL_CHECK(false);
982     }
983   }
984 
985   // Should be called after interceptors are done running on the finalize result
986   // path
ContinueFinalizeResultAfterInterception()987   void ContinueFinalizeResultAfterInterception() override {
988     done_intercepting_ = true;
989     // The following call_start_batch is internally-generated so no need for an
990     // explanatory log on failure.
991     ABSL_CHECK(grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag(),
992                                      nullptr) == GRPC_CALL_OK);
993   }
994 
995  private:
996   // Returns true if no interceptors need to be run
RunInterceptors()997   bool RunInterceptors() {
998     interceptor_methods_.ClearState();
999     interceptor_methods_.SetCallOpSetInterface(this);
1000     interceptor_methods_.SetCall(&call_);
1001     this->Op1::SetInterceptionHookPoint(&interceptor_methods_);
1002     this->Op2::SetInterceptionHookPoint(&interceptor_methods_);
1003     this->Op3::SetInterceptionHookPoint(&interceptor_methods_);
1004     this->Op4::SetInterceptionHookPoint(&interceptor_methods_);
1005     this->Op5::SetInterceptionHookPoint(&interceptor_methods_);
1006     this->Op6::SetInterceptionHookPoint(&interceptor_methods_);
1007     if (interceptor_methods_.InterceptorsListEmpty()) {
1008       return true;
1009     }
1010     // This call will go through interceptors and would need to
1011     // schedule new batches, so delay completion queue shutdown
1012     call_.cq()->RegisterAvalanching();
1013     return interceptor_methods_.RunInterceptors();
1014   }
1015   // Returns true if no interceptors need to be run
RunInterceptorsPostRecv()1016   bool RunInterceptorsPostRecv() {
1017     // Call and OpSet had already been set on the set state.
1018     // SetReverse also clears previously set hook points
1019     interceptor_methods_.SetReverse();
1020     this->Op1::SetFinishInterceptionHookPoint(&interceptor_methods_);
1021     this->Op2::SetFinishInterceptionHookPoint(&interceptor_methods_);
1022     this->Op3::SetFinishInterceptionHookPoint(&interceptor_methods_);
1023     this->Op4::SetFinishInterceptionHookPoint(&interceptor_methods_);
1024     this->Op5::SetFinishInterceptionHookPoint(&interceptor_methods_);
1025     this->Op6::SetFinishInterceptionHookPoint(&interceptor_methods_);
1026     return interceptor_methods_.RunInterceptors();
1027   }
1028 
1029   void* core_cq_tag_;
1030   void* return_tag_;
1031   Call call_;
1032   bool done_intercepting_ = false;
1033   InterceptorBatchMethodsImpl interceptor_methods_;
1034   bool saved_status_;
1035 };
1036 
1037 }  // namespace internal
1038 }  // namespace grpc
1039 
1040 #endif  // GRPCPP_IMPL_CALL_OP_SET_H
1041