• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
21 
22 #include <array>
23 #include <functional>
24 
25 #include <grpcpp/impl/codegen/call.h>
26 #include <grpcpp/impl/codegen/call_op_set_interface.h>
27 #include <grpcpp/impl/codegen/client_interceptor.h>
28 #include <grpcpp/impl/codegen/intercepted_channel.h>
29 #include <grpcpp/impl/codegen/server_interceptor.h>
30 
31 #include <grpc/impl/codegen/grpc_types.h>
32 
33 namespace grpc {
34 namespace internal {
35 
36 class InterceptorBatchMethodsImpl
37     : public experimental::InterceptorBatchMethods {
38  public:
InterceptorBatchMethodsImpl()39   InterceptorBatchMethodsImpl() {
40     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
41          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
42          i = static_cast<experimental::InterceptionHookPoints>(
43              static_cast<size_t>(i) + 1)) {
44       hooks_[static_cast<size_t>(i)] = false;
45     }
46   }
47 
~InterceptorBatchMethodsImpl()48   ~InterceptorBatchMethodsImpl() {}
49 
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50   bool QueryInterceptionHookPoint(
51       experimental::InterceptionHookPoints type) override {
52     return hooks_[static_cast<size_t>(type)];
53   }
54 
Proceed()55   void Proceed() override {
56     if (call_->client_rpc_info() != nullptr) {
57       return ProceedClient();
58     }
59     GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr);
60     ProceedServer();
61   }
62 
Hijack()63   void Hijack() override {
64     // Only the client can hijack when sending down initial metadata
65     GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr &&
66                        call_->client_rpc_info() != nullptr);
67     // It is illegal to call Hijack twice
68     GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_);
69     auto* rpc_info = call_->client_rpc_info();
70     rpc_info->hijacked_ = true;
71     rpc_info->hijacked_interceptor_ = current_interceptor_index_;
72     ClearHookPoints();
73     ops_->SetHijackingState();
74     ran_hijacking_interceptor_ = true;
75     rpc_info->RunInterceptor(this, current_interceptor_index_);
76   }
77 
AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78   void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
79     hooks_[static_cast<size_t>(type)] = true;
80   }
81 
GetSerializedSendMessage()82   ByteBuffer* GetSerializedSendMessage() override {
83     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
84     if (*orig_send_message_ != nullptr) {
85       GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
86       *orig_send_message_ = nullptr;
87     }
88     return send_message_;
89   }
90 
GetSendMessage()91   const void* GetSendMessage() override {
92     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
93     return *orig_send_message_;
94   }
95 
ModifySendMessage(const void * message)96   void ModifySendMessage(const void* message) override {
97     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
98     *orig_send_message_ = message;
99   }
100 
GetSendMessageStatus()101   bool GetSendMessageStatus() override { return !*fail_send_message_; }
102 
GetSendInitialMetadata()103   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
104     return send_initial_metadata_;
105   }
106 
GetSendStatus()107   Status GetSendStatus() override {
108     return Status(static_cast<StatusCode>(*code_), *error_message_,
109                   *error_details_);
110   }
111 
ModifySendStatus(const Status & status)112   void ModifySendStatus(const Status& status) override {
113     *code_ = static_cast<grpc_status_code>(status.error_code());
114     *error_details_ = status.error_details();
115     *error_message_ = status.error_message();
116   }
117 
GetSendTrailingMetadata()118   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
119     return send_trailing_metadata_;
120   }
121 
GetRecvMessage()122   void* GetRecvMessage() override { return recv_message_; }
123 
GetRecvInitialMetadata()124   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
125       override {
126     return recv_initial_metadata_->map();
127   }
128 
GetRecvStatus()129   Status* GetRecvStatus() override { return recv_status_; }
130 
FailHijackedSendMessage()131   void FailHijackedSendMessage() override {
132     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
133         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
134     *fail_send_message_ = true;
135   }
136 
GetRecvTrailingMetadata()137   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
138       override {
139     return recv_trailing_metadata_->map();
140   }
141 
SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142   void SetSendMessage(ByteBuffer* buf, const void** msg,
143                       bool* fail_send_message,
144                       std::function<Status(const void*)> serializer) {
145     send_message_ = buf;
146     orig_send_message_ = msg;
147     fail_send_message_ = fail_send_message;
148     serializer_ = serializer;
149   }
150 
SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151   void SetSendInitialMetadata(
152       std::multimap<std::string, std::string>* metadata) {
153     send_initial_metadata_ = metadata;
154   }
155 
SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156   void SetSendStatus(grpc_status_code* code, std::string* error_details,
157                      std::string* error_message) {
158     code_ = code;
159     error_details_ = error_details;
160     error_message_ = error_message;
161   }
162 
SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163   void SetSendTrailingMetadata(
164       std::multimap<std::string, std::string>* metadata) {
165     send_trailing_metadata_ = metadata;
166   }
167 
SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168   void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
169     recv_message_ = message;
170     hijacked_recv_message_failed_ = hijacked_recv_message_failed;
171   }
172 
SetRecvInitialMetadata(MetadataMap * map)173   void SetRecvInitialMetadata(MetadataMap* map) {
174     recv_initial_metadata_ = map;
175   }
176 
SetRecvStatus(Status * status)177   void SetRecvStatus(Status* status) { recv_status_ = status; }
178 
SetRecvTrailingMetadata(MetadataMap * map)179   void SetRecvTrailingMetadata(MetadataMap* map) {
180     recv_trailing_metadata_ = map;
181   }
182 
GetInterceptedChannel()183   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
184     auto* info = call_->client_rpc_info();
185     if (info == nullptr) {
186       return std::unique_ptr<ChannelInterface>(nullptr);
187     }
188     // The intercepted channel starts from the interceptor just after the
189     // current interceptor
190     return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
191         info->channel(), current_interceptor_index_ + 1));
192   }
193 
FailHijackedRecvMessage()194   void FailHijackedRecvMessage() override {
195     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
196         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
197     *hijacked_recv_message_failed_ = true;
198   }
199 
200   // Clears all state
ClearState()201   void ClearState() {
202     reverse_ = false;
203     ran_hijacking_interceptor_ = false;
204     ClearHookPoints();
205   }
206 
207   // Prepares for Post_recv operations
SetReverse()208   void SetReverse() {
209     reverse_ = true;
210     ran_hijacking_interceptor_ = false;
211     ClearHookPoints();
212   }
213 
214   // This needs to be set before interceptors are run
SetCall(Call * call)215   void SetCall(Call* call) { call_ = call; }
216 
217   // This needs to be set before interceptors are run using RunInterceptors().
218   // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
SetCallOpSetInterface(CallOpSetInterface * ops)219   void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
220 
221   // SetCall should have been called before this.
222   // Returns true if the interceptors list is empty
InterceptorsListEmpty()223   bool InterceptorsListEmpty() {
224     auto* client_rpc_info = call_->client_rpc_info();
225     if (client_rpc_info != nullptr) {
226       if (client_rpc_info->interceptors_.size() == 0) {
227         return true;
228       } else {
229         return false;
230       }
231     }
232 
233     auto* server_rpc_info = call_->server_rpc_info();
234     if (server_rpc_info == nullptr ||
235         server_rpc_info->interceptors_.size() == 0) {
236       return true;
237     }
238     return false;
239   }
240 
241   // This should be used only by subclasses of CallOpSetInterface. SetCall and
242   // SetCallOpSetInterface should have been called before this. After all the
243   // interceptors are done running, either ContinueFillOpsAfterInterception or
244   // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
245   // them is invoked if there were no interceptors registered.
RunInterceptors()246   bool RunInterceptors() {
247     GPR_CODEGEN_ASSERT(ops_);
248     auto* client_rpc_info = call_->client_rpc_info();
249     if (client_rpc_info != nullptr) {
250       if (client_rpc_info->interceptors_.size() == 0) {
251         return true;
252       } else {
253         RunClientInterceptors();
254         return false;
255       }
256     }
257 
258     auto* server_rpc_info = call_->server_rpc_info();
259     if (server_rpc_info == nullptr ||
260         server_rpc_info->interceptors_.size() == 0) {
261       return true;
262     }
263     RunServerInterceptors();
264     return false;
265   }
266 
267   // Returns true if no interceptors are run. Returns false otherwise if there
268   // are interceptors registered. After the interceptors are done running \a f
269   // will be invoked. This is to be used only by BaseAsyncRequest and
270   // SyncRequest.
RunInterceptors(std::function<void (void)> f)271   bool RunInterceptors(std::function<void(void)> f) {
272     // This is used only by the server for initial call request
273     GPR_CODEGEN_ASSERT(reverse_ == true);
274     GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
275     auto* server_rpc_info = call_->server_rpc_info();
276     if (server_rpc_info == nullptr ||
277         server_rpc_info->interceptors_.size() == 0) {
278       return true;
279     }
280     callback_ = std::move(f);
281     RunServerInterceptors();
282     return false;
283   }
284 
285  private:
RunClientInterceptors()286   void RunClientInterceptors() {
287     auto* rpc_info = call_->client_rpc_info();
288     if (!reverse_) {
289       current_interceptor_index_ = 0;
290     } else {
291       if (rpc_info->hijacked_) {
292         current_interceptor_index_ = rpc_info->hijacked_interceptor_;
293       } else {
294         current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
295       }
296     }
297     rpc_info->RunInterceptor(this, current_interceptor_index_);
298   }
299 
RunServerInterceptors()300   void RunServerInterceptors() {
301     auto* rpc_info = call_->server_rpc_info();
302     if (!reverse_) {
303       current_interceptor_index_ = 0;
304     } else {
305       current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
306     }
307     rpc_info->RunInterceptor(this, current_interceptor_index_);
308   }
309 
ProceedClient()310   void ProceedClient() {
311     auto* rpc_info = call_->client_rpc_info();
312     if (rpc_info->hijacked_ && !reverse_ &&
313         current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
314         !ran_hijacking_interceptor_) {
315       // We now need to provide hijacked recv ops to this interceptor
316       ClearHookPoints();
317       ops_->SetHijackingState();
318       ran_hijacking_interceptor_ = true;
319       rpc_info->RunInterceptor(this, current_interceptor_index_);
320       return;
321     }
322     if (!reverse_) {
323       current_interceptor_index_++;
324       // We are going down the stack of interceptors
325       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
326         if (rpc_info->hijacked_ &&
327             current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
328           // This is a hijacked RPC and we are done with hijacking
329           ops_->ContinueFillOpsAfterInterception();
330         } else {
331           rpc_info->RunInterceptor(this, current_interceptor_index_);
332         }
333       } else {
334         // we are done running all the interceptors without any hijacking
335         ops_->ContinueFillOpsAfterInterception();
336       }
337     } else {
338       // We are going up the stack of interceptors
339       if (current_interceptor_index_ > 0) {
340         // Continue running interceptors
341         current_interceptor_index_--;
342         rpc_info->RunInterceptor(this, current_interceptor_index_);
343       } else {
344         // we are done running all the interceptors without any hijacking
345         ops_->ContinueFinalizeResultAfterInterception();
346       }
347     }
348   }
349 
ProceedServer()350   void ProceedServer() {
351     auto* rpc_info = call_->server_rpc_info();
352     if (!reverse_) {
353       current_interceptor_index_++;
354       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
355         return rpc_info->RunInterceptor(this, current_interceptor_index_);
356       } else if (ops_) {
357         return ops_->ContinueFillOpsAfterInterception();
358       }
359     } else {
360       // We are going up the stack of interceptors
361       if (current_interceptor_index_ > 0) {
362         // Continue running interceptors
363         current_interceptor_index_--;
364         return rpc_info->RunInterceptor(this, current_interceptor_index_);
365       } else if (ops_) {
366         return ops_->ContinueFinalizeResultAfterInterception();
367       }
368     }
369     GPR_CODEGEN_ASSERT(callback_);
370     callback_();
371   }
372 
ClearHookPoints()373   void ClearHookPoints() {
374     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
375          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
376          i = static_cast<experimental::InterceptionHookPoints>(
377              static_cast<size_t>(i) + 1)) {
378       hooks_[static_cast<size_t>(i)] = false;
379     }
380   }
381 
382   std::array<bool,
383              static_cast<size_t>(
384                  experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
385       hooks_;
386 
387   size_t current_interceptor_index_ = 0;  // Current iterator
388   bool reverse_ = false;
389   bool ran_hijacking_interceptor_ = false;
390   Call* call_ = nullptr;  // The Call object is present along with CallOpSet
391                           // object/callback
392   CallOpSetInterface* ops_ = nullptr;
393   std::function<void(void)> callback_;
394 
395   ByteBuffer* send_message_ = nullptr;
396   bool* fail_send_message_ = nullptr;
397   const void** orig_send_message_ = nullptr;
398   std::function<Status(const void*)> serializer_;
399 
400   std::multimap<std::string, std::string>* send_initial_metadata_;
401 
402   grpc_status_code* code_ = nullptr;
403   std::string* error_details_ = nullptr;
404   std::string* error_message_ = nullptr;
405 
406   std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
407 
408   void* recv_message_ = nullptr;
409   bool* hijacked_recv_message_failed_ = nullptr;
410 
411   MetadataMap* recv_initial_metadata_ = nullptr;
412 
413   Status* recv_status_ = nullptr;
414 
415   MetadataMap* recv_trailing_metadata_ = nullptr;
416 };
417 
418 // A special implementation of InterceptorBatchMethods to send a Cancel
419 // notification down the interceptor stack
420 class CancelInterceptorBatchMethods
421     : public experimental::InterceptorBatchMethods {
422  public:
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)423   bool QueryInterceptionHookPoint(
424       experimental::InterceptionHookPoints type) override {
425     if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) {
426       return true;
427     } else {
428       return false;
429     }
430   }
431 
Proceed()432   void Proceed() override {
433     // This is a no-op. For actual continuation of the RPC simply needs to
434     // return from the Intercept method
435   }
436 
Hijack()437   void Hijack() override {
438     // Only the client can hijack when sending down initial metadata
439     GPR_CODEGEN_ASSERT(false &&
440                        "It is illegal to call Hijack on a method which has a "
441                        "Cancel notification");
442   }
443 
GetSerializedSendMessage()444   ByteBuffer* GetSerializedSendMessage() override {
445     GPR_CODEGEN_ASSERT(false &&
446                        "It is illegal to call GetSendMessage on a method which "
447                        "has a Cancel notification");
448     return nullptr;
449   }
450 
GetSendMessageStatus()451   bool GetSendMessageStatus() override {
452     GPR_CODEGEN_ASSERT(
453         false &&
454         "It is illegal to call GetSendMessageStatus on a method which "
455         "has a Cancel notification");
456     return false;
457   }
458 
GetSendMessage()459   const void* GetSendMessage() override {
460     GPR_CODEGEN_ASSERT(
461         false &&
462         "It is illegal to call GetOriginalSendMessage on a method which "
463         "has a Cancel notification");
464     return nullptr;
465   }
466 
ModifySendMessage(const void *)467   void ModifySendMessage(const void* /*message*/) override {
468     GPR_CODEGEN_ASSERT(
469         false &&
470         "It is illegal to call ModifySendMessage on a method which "
471         "has a Cancel notification");
472   }
473 
GetSendInitialMetadata()474   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
475     GPR_CODEGEN_ASSERT(false &&
476                        "It is illegal to call GetSendInitialMetadata on a "
477                        "method which has a Cancel notification");
478     return nullptr;
479   }
480 
GetSendStatus()481   Status GetSendStatus() override {
482     GPR_CODEGEN_ASSERT(false &&
483                        "It is illegal to call GetSendStatus on a method which "
484                        "has a Cancel notification");
485     return Status();
486   }
487 
ModifySendStatus(const Status &)488   void ModifySendStatus(const Status& /*status*/) override {
489     GPR_CODEGEN_ASSERT(false &&
490                        "It is illegal to call ModifySendStatus on a method "
491                        "which has a Cancel notification");
492     return;
493   }
494 
GetSendTrailingMetadata()495   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
496     GPR_CODEGEN_ASSERT(false &&
497                        "It is illegal to call GetSendTrailingMetadata on a "
498                        "method which has a Cancel notification");
499     return nullptr;
500   }
501 
GetRecvMessage()502   void* GetRecvMessage() override {
503     GPR_CODEGEN_ASSERT(false &&
504                        "It is illegal to call GetRecvMessage on a method which "
505                        "has a Cancel notification");
506     return nullptr;
507   }
508 
GetRecvInitialMetadata()509   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
510       override {
511     GPR_CODEGEN_ASSERT(false &&
512                        "It is illegal to call GetRecvInitialMetadata on a "
513                        "method which has a Cancel notification");
514     return nullptr;
515   }
516 
GetRecvStatus()517   Status* GetRecvStatus() override {
518     GPR_CODEGEN_ASSERT(false &&
519                        "It is illegal to call GetRecvStatus on a method which "
520                        "has a Cancel notification");
521     return nullptr;
522   }
523 
GetRecvTrailingMetadata()524   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
525       override {
526     GPR_CODEGEN_ASSERT(false &&
527                        "It is illegal to call GetRecvTrailingMetadata on a "
528                        "method which has a Cancel notification");
529     return nullptr;
530   }
531 
GetInterceptedChannel()532   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
533     GPR_CODEGEN_ASSERT(false &&
534                        "It is illegal to call GetInterceptedChannel on a "
535                        "method which has a Cancel notification");
536     return std::unique_ptr<ChannelInterface>(nullptr);
537   }
538 
FailHijackedRecvMessage()539   void FailHijackedRecvMessage() override {
540     GPR_CODEGEN_ASSERT(false &&
541                        "It is illegal to call FailHijackedRecvMessage on a "
542                        "method which has a Cancel notification");
543   }
544 
FailHijackedSendMessage()545   void FailHijackedSendMessage() override {
546     GPR_CODEGEN_ASSERT(false &&
547                        "It is illegal to call FailHijackedSendMessage on a "
548                        "method which has a Cancel notification");
549   }
550 };
551 }  // namespace internal
552 }  // namespace grpc
553 
554 #endif  // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
555