• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
21 
22 #include <array>
23 #include <functional>
24 
25 #include <grpcpp/impl/codegen/call.h>
26 #include <grpcpp/impl/codegen/call_op_set_interface.h>
27 #include <grpcpp/impl/codegen/client_interceptor.h>
28 #include <grpcpp/impl/codegen/intercepted_channel.h>
29 #include <grpcpp/impl/codegen/server_interceptor.h>
30 
31 #include <grpc/impl/codegen/grpc_types.h>
32 
33 namespace grpc {
34 namespace internal {
35 
36 class InterceptorBatchMethodsImpl
37     : public experimental::InterceptorBatchMethods {
38  public:
InterceptorBatchMethodsImpl()39   InterceptorBatchMethodsImpl() {
40     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
41          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
42          i = static_cast<experimental::InterceptionHookPoints>(
43              static_cast<size_t>(i) + 1)) {
44       hooks_[static_cast<size_t>(i)] = false;
45     }
46   }
47 
~InterceptorBatchMethodsImpl()48   ~InterceptorBatchMethodsImpl() override {}
49 
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50   bool QueryInterceptionHookPoint(
51       experimental::InterceptionHookPoints type) override {
52     return hooks_[static_cast<size_t>(type)];
53   }
54 
Proceed()55   void Proceed() override {
56     if (call_->client_rpc_info() != nullptr) {
57       return ProceedClient();
58     }
59     GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr);
60     ProceedServer();
61   }
62 
Hijack()63   void Hijack() override {
64     // Only the client can hijack when sending down initial metadata
65     GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr &&
66                        call_->client_rpc_info() != nullptr);
67     // It is illegal to call Hijack twice
68     GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_);
69     auto* rpc_info = call_->client_rpc_info();
70     rpc_info->hijacked_ = true;
71     rpc_info->hijacked_interceptor_ = current_interceptor_index_;
72     ClearHookPoints();
73     ops_->SetHijackingState();
74     ran_hijacking_interceptor_ = true;
75     rpc_info->RunInterceptor(this, current_interceptor_index_);
76   }
77 
AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78   void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
79     hooks_[static_cast<size_t>(type)] = true;
80   }
81 
GetSerializedSendMessage()82   ByteBuffer* GetSerializedSendMessage() override {
83     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
84     if (*orig_send_message_ != nullptr) {
85       GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
86       *orig_send_message_ = nullptr;
87     }
88     return send_message_;
89   }
90 
GetSendMessage()91   const void* GetSendMessage() override {
92     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
93     return *orig_send_message_;
94   }
95 
ModifySendMessage(const void * message)96   void ModifySendMessage(const void* message) override {
97     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
98     *orig_send_message_ = message;
99   }
100 
GetSendMessageStatus()101   bool GetSendMessageStatus() override { return !*fail_send_message_; }
102 
GetSendInitialMetadata()103   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
104     return send_initial_metadata_;
105   }
106 
GetSendStatus()107   Status GetSendStatus() override {
108     return Status(static_cast<StatusCode>(*code_), *error_message_,
109                   *error_details_);
110   }
111 
ModifySendStatus(const Status & status)112   void ModifySendStatus(const Status& status) override {
113     *code_ = static_cast<grpc_status_code>(status.error_code());
114     *error_details_ = status.error_details();
115     *error_message_ = status.error_message();
116   }
117 
GetSendTrailingMetadata()118   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
119     return send_trailing_metadata_;
120   }
121 
GetRecvMessage()122   void* GetRecvMessage() override { return recv_message_; }
123 
GetRecvInitialMetadata()124   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
125       override {
126     return recv_initial_metadata_->map();
127   }
128 
GetRecvStatus()129   Status* GetRecvStatus() override { return recv_status_; }
130 
FailHijackedSendMessage()131   void FailHijackedSendMessage() override {
132     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
133         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
134     *fail_send_message_ = true;
135   }
136 
GetRecvTrailingMetadata()137   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
138       override {
139     return recv_trailing_metadata_->map();
140   }
141 
SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142   void SetSendMessage(ByteBuffer* buf, const void** msg,
143                       bool* fail_send_message,
144                       std::function<Status(const void*)> serializer) {
145     send_message_ = buf;
146     orig_send_message_ = msg;
147     fail_send_message_ = fail_send_message;
148     serializer_ = serializer;
149   }
150 
SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151   void SetSendInitialMetadata(
152       std::multimap<std::string, std::string>* metadata) {
153     send_initial_metadata_ = metadata;
154   }
155 
SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156   void SetSendStatus(grpc_status_code* code, std::string* error_details,
157                      std::string* error_message) {
158     code_ = code;
159     error_details_ = error_details;
160     error_message_ = error_message;
161   }
162 
SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163   void SetSendTrailingMetadata(
164       std::multimap<std::string, std::string>* metadata) {
165     send_trailing_metadata_ = metadata;
166   }
167 
SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168   void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
169     recv_message_ = message;
170     hijacked_recv_message_failed_ = hijacked_recv_message_failed;
171   }
172 
SetRecvInitialMetadata(MetadataMap * map)173   void SetRecvInitialMetadata(MetadataMap* map) {
174     recv_initial_metadata_ = map;
175   }
176 
SetRecvStatus(Status * status)177   void SetRecvStatus(Status* status) { recv_status_ = status; }
178 
SetRecvTrailingMetadata(MetadataMap * map)179   void SetRecvTrailingMetadata(MetadataMap* map) {
180     recv_trailing_metadata_ = map;
181   }
182 
GetInterceptedChannel()183   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
184     auto* info = call_->client_rpc_info();
185     if (info == nullptr) {
186       return std::unique_ptr<ChannelInterface>(nullptr);
187     }
188     // The intercepted channel starts from the interceptor just after the
189     // current interceptor
190     return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
191         info->channel(), current_interceptor_index_ + 1));
192   }
193 
FailHijackedRecvMessage()194   void FailHijackedRecvMessage() override {
195     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
196         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
197     *hijacked_recv_message_failed_ = true;
198   }
199 
200   // Clears all state
ClearState()201   void ClearState() {
202     reverse_ = false;
203     ran_hijacking_interceptor_ = false;
204     ClearHookPoints();
205   }
206 
207   // Prepares for Post_recv operations
SetReverse()208   void SetReverse() {
209     reverse_ = true;
210     ran_hijacking_interceptor_ = false;
211     ClearHookPoints();
212   }
213 
214   // This needs to be set before interceptors are run
SetCall(Call * call)215   void SetCall(Call* call) { call_ = call; }
216 
217   // This needs to be set before interceptors are run using RunInterceptors().
218   // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
SetCallOpSetInterface(CallOpSetInterface * ops)219   void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
220 
221   // SetCall should have been called before this.
222   // Returns true if the interceptors list is empty
InterceptorsListEmpty()223   bool InterceptorsListEmpty() {
224     auto* client_rpc_info = call_->client_rpc_info();
225     if (client_rpc_info != nullptr) {
226       return client_rpc_info->interceptors_.empty();
227     }
228 
229     auto* server_rpc_info = call_->server_rpc_info();
230     return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty();
231   }
232 
233   // This should be used only by subclasses of CallOpSetInterface. SetCall and
234   // SetCallOpSetInterface should have been called before this. After all the
235   // interceptors are done running, either ContinueFillOpsAfterInterception or
236   // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
237   // them is invoked if there were no interceptors registered.
RunInterceptors()238   bool RunInterceptors() {
239     GPR_CODEGEN_ASSERT(ops_);
240     auto* client_rpc_info = call_->client_rpc_info();
241     if (client_rpc_info != nullptr) {
242       if (client_rpc_info->interceptors_.empty()) {
243         return true;
244       } else {
245         RunClientInterceptors();
246         return false;
247       }
248     }
249 
250     auto* server_rpc_info = call_->server_rpc_info();
251     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
252       return true;
253     }
254     RunServerInterceptors();
255     return false;
256   }
257 
258   // Returns true if no interceptors are run. Returns false otherwise if there
259   // are interceptors registered. After the interceptors are done running \a f
260   // will be invoked. This is to be used only by BaseAsyncRequest and
261   // SyncRequest.
RunInterceptors(std::function<void (void)> f)262   bool RunInterceptors(std::function<void(void)> f) {
263     // This is used only by the server for initial call request
264     GPR_CODEGEN_ASSERT(reverse_ == true);
265     GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
266     auto* server_rpc_info = call_->server_rpc_info();
267     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
268       return true;
269     }
270     callback_ = std::move(f);
271     RunServerInterceptors();
272     return false;
273   }
274 
275  private:
RunClientInterceptors()276   void RunClientInterceptors() {
277     auto* rpc_info = call_->client_rpc_info();
278     if (!reverse_) {
279       current_interceptor_index_ = 0;
280     } else {
281       if (rpc_info->hijacked_) {
282         current_interceptor_index_ = rpc_info->hijacked_interceptor_;
283       } else {
284         current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
285       }
286     }
287     rpc_info->RunInterceptor(this, current_interceptor_index_);
288   }
289 
RunServerInterceptors()290   void RunServerInterceptors() {
291     auto* rpc_info = call_->server_rpc_info();
292     if (!reverse_) {
293       current_interceptor_index_ = 0;
294     } else {
295       current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
296     }
297     rpc_info->RunInterceptor(this, current_interceptor_index_);
298   }
299 
ProceedClient()300   void ProceedClient() {
301     auto* rpc_info = call_->client_rpc_info();
302     if (rpc_info->hijacked_ && !reverse_ &&
303         current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
304         !ran_hijacking_interceptor_) {
305       // We now need to provide hijacked recv ops to this interceptor
306       ClearHookPoints();
307       ops_->SetHijackingState();
308       ran_hijacking_interceptor_ = true;
309       rpc_info->RunInterceptor(this, current_interceptor_index_);
310       return;
311     }
312     if (!reverse_) {
313       current_interceptor_index_++;
314       // We are going down the stack of interceptors
315       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
316         if (rpc_info->hijacked_ &&
317             current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
318           // This is a hijacked RPC and we are done with hijacking
319           ops_->ContinueFillOpsAfterInterception();
320         } else {
321           rpc_info->RunInterceptor(this, current_interceptor_index_);
322         }
323       } else {
324         // we are done running all the interceptors without any hijacking
325         ops_->ContinueFillOpsAfterInterception();
326       }
327     } else {
328       // We are going up the stack of interceptors
329       if (current_interceptor_index_ > 0) {
330         // Continue running interceptors
331         current_interceptor_index_--;
332         rpc_info->RunInterceptor(this, current_interceptor_index_);
333       } else {
334         // we are done running all the interceptors without any hijacking
335         ops_->ContinueFinalizeResultAfterInterception();
336       }
337     }
338   }
339 
ProceedServer()340   void ProceedServer() {
341     auto* rpc_info = call_->server_rpc_info();
342     if (!reverse_) {
343       current_interceptor_index_++;
344       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
345         return rpc_info->RunInterceptor(this, current_interceptor_index_);
346       } else if (ops_) {
347         return ops_->ContinueFillOpsAfterInterception();
348       }
349     } else {
350       // We are going up the stack of interceptors
351       if (current_interceptor_index_ > 0) {
352         // Continue running interceptors
353         current_interceptor_index_--;
354         return rpc_info->RunInterceptor(this, current_interceptor_index_);
355       } else if (ops_) {
356         return ops_->ContinueFinalizeResultAfterInterception();
357       }
358     }
359     GPR_CODEGEN_ASSERT(callback_);
360     callback_();
361   }
362 
ClearHookPoints()363   void ClearHookPoints() {
364     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
365          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
366          i = static_cast<experimental::InterceptionHookPoints>(
367              static_cast<size_t>(i) + 1)) {
368       hooks_[static_cast<size_t>(i)] = false;
369     }
370   }
371 
372   std::array<bool,
373              static_cast<size_t>(
374                  experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
375       hooks_;
376 
377   size_t current_interceptor_index_ = 0;  // Current iterator
378   bool reverse_ = false;
379   bool ran_hijacking_interceptor_ = false;
380   Call* call_ = nullptr;  // The Call object is present along with CallOpSet
381                           // object/callback
382   CallOpSetInterface* ops_ = nullptr;
383   std::function<void(void)> callback_;
384 
385   ByteBuffer* send_message_ = nullptr;
386   bool* fail_send_message_ = nullptr;
387   const void** orig_send_message_ = nullptr;
388   std::function<Status(const void*)> serializer_;
389 
390   std::multimap<std::string, std::string>* send_initial_metadata_;
391 
392   grpc_status_code* code_ = nullptr;
393   std::string* error_details_ = nullptr;
394   std::string* error_message_ = nullptr;
395 
396   std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
397 
398   void* recv_message_ = nullptr;
399   bool* hijacked_recv_message_failed_ = nullptr;
400 
401   MetadataMap* recv_initial_metadata_ = nullptr;
402 
403   Status* recv_status_ = nullptr;
404 
405   MetadataMap* recv_trailing_metadata_ = nullptr;
406 };
407 
408 // A special implementation of InterceptorBatchMethods to send a Cancel
409 // notification down the interceptor stack
410 class CancelInterceptorBatchMethods
411     : public experimental::InterceptorBatchMethods {
412  public:
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)413   bool QueryInterceptionHookPoint(
414       experimental::InterceptionHookPoints type) override {
415     return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL;
416   }
417 
Proceed()418   void Proceed() override {
419     // This is a no-op. For actual continuation of the RPC simply needs to
420     // return from the Intercept method
421   }
422 
Hijack()423   void Hijack() override {
424     // Only the client can hijack when sending down initial metadata
425     GPR_CODEGEN_ASSERT(false &&
426                        "It is illegal to call Hijack on a method which has a "
427                        "Cancel notification");
428   }
429 
GetSerializedSendMessage()430   ByteBuffer* GetSerializedSendMessage() override {
431     GPR_CODEGEN_ASSERT(false &&
432                        "It is illegal to call GetSendMessage on a method which "
433                        "has a Cancel notification");
434     return nullptr;
435   }
436 
GetSendMessageStatus()437   bool GetSendMessageStatus() override {
438     GPR_CODEGEN_ASSERT(
439         false &&
440         "It is illegal to call GetSendMessageStatus on a method which "
441         "has a Cancel notification");
442     return false;
443   }
444 
GetSendMessage()445   const void* GetSendMessage() override {
446     GPR_CODEGEN_ASSERT(
447         false &&
448         "It is illegal to call GetOriginalSendMessage on a method which "
449         "has a Cancel notification");
450     return nullptr;
451   }
452 
ModifySendMessage(const void *)453   void ModifySendMessage(const void* /*message*/) override {
454     GPR_CODEGEN_ASSERT(
455         false &&
456         "It is illegal to call ModifySendMessage on a method which "
457         "has a Cancel notification");
458   }
459 
GetSendInitialMetadata()460   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
461     GPR_CODEGEN_ASSERT(false &&
462                        "It is illegal to call GetSendInitialMetadata on a "
463                        "method which has a Cancel notification");
464     return nullptr;
465   }
466 
GetSendStatus()467   Status GetSendStatus() override {
468     GPR_CODEGEN_ASSERT(false &&
469                        "It is illegal to call GetSendStatus on a method which "
470                        "has a Cancel notification");
471     return Status();
472   }
473 
ModifySendStatus(const Status &)474   void ModifySendStatus(const Status& /*status*/) override {
475     GPR_CODEGEN_ASSERT(false &&
476                        "It is illegal to call ModifySendStatus on a method "
477                        "which has a Cancel notification");
478   }
479 
GetSendTrailingMetadata()480   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
481     GPR_CODEGEN_ASSERT(false &&
482                        "It is illegal to call GetSendTrailingMetadata on a "
483                        "method which has a Cancel notification");
484     return nullptr;
485   }
486 
GetRecvMessage()487   void* GetRecvMessage() override {
488     GPR_CODEGEN_ASSERT(false &&
489                        "It is illegal to call GetRecvMessage on a method which "
490                        "has a Cancel notification");
491     return nullptr;
492   }
493 
GetRecvInitialMetadata()494   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
495       override {
496     GPR_CODEGEN_ASSERT(false &&
497                        "It is illegal to call GetRecvInitialMetadata on a "
498                        "method which has a Cancel notification");
499     return nullptr;
500   }
501 
GetRecvStatus()502   Status* GetRecvStatus() override {
503     GPR_CODEGEN_ASSERT(false &&
504                        "It is illegal to call GetRecvStatus on a method which "
505                        "has a Cancel notification");
506     return nullptr;
507   }
508 
GetRecvTrailingMetadata()509   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
510       override {
511     GPR_CODEGEN_ASSERT(false &&
512                        "It is illegal to call GetRecvTrailingMetadata on a "
513                        "method which has a Cancel notification");
514     return nullptr;
515   }
516 
GetInterceptedChannel()517   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
518     GPR_CODEGEN_ASSERT(false &&
519                        "It is illegal to call GetInterceptedChannel on a "
520                        "method which has a Cancel notification");
521     return std::unique_ptr<ChannelInterface>(nullptr);
522   }
523 
FailHijackedRecvMessage()524   void FailHijackedRecvMessage() override {
525     GPR_CODEGEN_ASSERT(false &&
526                        "It is illegal to call FailHijackedRecvMessage on a "
527                        "method which has a Cancel notification");
528   }
529 
FailHijackedSendMessage()530   void FailHijackedSendMessage() override {
531     GPR_CODEGEN_ASSERT(false &&
532                        "It is illegal to call FailHijackedSendMessage on a "
533                        "method which has a Cancel notification");
534   }
535 };
536 }  // namespace internal
537 }  // namespace grpc
538 
539 #endif  // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
540