• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
21 
22 #include <array>
23 #include <functional>
24 
25 #include <grpcpp/impl/codegen/call.h>
26 #include <grpcpp/impl/codegen/call_op_set_interface.h>
27 #include <grpcpp/impl/codegen/client_interceptor.h>
28 #include <grpcpp/impl/codegen/intercepted_channel.h>
29 #include <grpcpp/impl/codegen/server_interceptor.h>
30 
31 #include <grpc/impl/codegen/grpc_types.h>
32 
33 namespace grpc {
34 namespace internal {
35 
36 class InterceptorBatchMethodsImpl
37     : public experimental::InterceptorBatchMethods {
38  public:
InterceptorBatchMethodsImpl()39   InterceptorBatchMethodsImpl() {
40     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
41          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
42          i = static_cast<experimental::InterceptionHookPoints>(
43              static_cast<size_t>(i) + 1)) {
44       hooks_[static_cast<size_t>(i)] = false;
45     }
46   }
47 
~InterceptorBatchMethodsImpl()48   ~InterceptorBatchMethodsImpl() override {}
49 
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50   bool QueryInterceptionHookPoint(
51       experimental::InterceptionHookPoints type) override {
52     return hooks_[static_cast<size_t>(type)];
53   }
54 
Proceed()55   void Proceed() override {
56     if (call_->client_rpc_info() != nullptr) {
57       return ProceedClient();
58     }
59     GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr);
60     ProceedServer();
61   }
62 
Hijack()63   void Hijack() override {
64     // Only the client can hijack when sending down initial metadata
65     GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr &&
66                        call_->client_rpc_info() != nullptr);
67     // It is illegal to call Hijack twice
68     GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_);
69     auto* rpc_info = call_->client_rpc_info();
70     rpc_info->hijacked_ = true;
71     rpc_info->hijacked_interceptor_ = current_interceptor_index_;
72     ClearHookPoints();
73     ops_->SetHijackingState();
74     ran_hijacking_interceptor_ = true;
75     rpc_info->RunInterceptor(this, current_interceptor_index_);
76   }
77 
AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78   void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
79     hooks_[static_cast<size_t>(type)] = true;
80   }
81 
GetSerializedSendMessage()82   ByteBuffer* GetSerializedSendMessage() override {
83     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
84     if (*orig_send_message_ != nullptr) {
85       GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
86       *orig_send_message_ = nullptr;
87     }
88     return send_message_;
89   }
90 
GetSendMessage()91   const void* GetSendMessage() override {
92     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
93     return *orig_send_message_;
94   }
95 
ModifySendMessage(const void * message)96   void ModifySendMessage(const void* message) override {
97     GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
98     *orig_send_message_ = message;
99   }
100 
GetSendMessageStatus()101   bool GetSendMessageStatus() override { return !*fail_send_message_; }
102 
GetSendInitialMetadata()103   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
104     return send_initial_metadata_;
105   }
106 
GetSendStatus()107   Status GetSendStatus() override {
108     return Status(static_cast<StatusCode>(*code_), *error_message_,
109                   *error_details_);
110   }
111 
ModifySendStatus(const Status & status)112   void ModifySendStatus(const Status& status) override {
113     *code_ = static_cast<grpc_status_code>(status.error_code());
114     *error_details_ = status.error_details();
115     *error_message_ = status.error_message();
116   }
117 
GetSendTrailingMetadata()118   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
119     return send_trailing_metadata_;
120   }
121 
GetRecvMessage()122   void* GetRecvMessage() override { return recv_message_; }
123 
GetRecvInitialMetadata()124   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
125       override {
126     return recv_initial_metadata_->map();
127   }
128 
GetRecvStatus()129   Status* GetRecvStatus() override { return recv_status_; }
130 
FailHijackedSendMessage()131   void FailHijackedSendMessage() override {
132     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
133         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
134     *fail_send_message_ = true;
135   }
136 
GetRecvTrailingMetadata()137   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
138       override {
139     return recv_trailing_metadata_->map();
140   }
141 
SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142   void SetSendMessage(ByteBuffer* buf, const void** msg,
143                       bool* fail_send_message,
144                       std::function<Status(const void*)> serializer) {
145     send_message_ = buf;
146     orig_send_message_ = msg;
147     fail_send_message_ = fail_send_message;
148     serializer_ = serializer;
149   }
150 
SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151   void SetSendInitialMetadata(
152       std::multimap<std::string, std::string>* metadata) {
153     send_initial_metadata_ = metadata;
154   }
155 
SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156   void SetSendStatus(grpc_status_code* code, std::string* error_details,
157                      std::string* error_message) {
158     code_ = code;
159     error_details_ = error_details;
160     error_message_ = error_message;
161   }
162 
SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163   void SetSendTrailingMetadata(
164       std::multimap<std::string, std::string>* metadata) {
165     send_trailing_metadata_ = metadata;
166   }
167 
SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168   void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
169     recv_message_ = message;
170     hijacked_recv_message_failed_ = hijacked_recv_message_failed;
171   }
172 
SetRecvInitialMetadata(MetadataMap * map)173   void SetRecvInitialMetadata(MetadataMap* map) {
174     recv_initial_metadata_ = map;
175   }
176 
SetRecvStatus(Status * status)177   void SetRecvStatus(Status* status) { recv_status_ = status; }
178 
SetRecvTrailingMetadata(MetadataMap * map)179   void SetRecvTrailingMetadata(MetadataMap* map) {
180     recv_trailing_metadata_ = map;
181   }
182 
GetInterceptedChannel()183   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
184     auto* info = call_->client_rpc_info();
185     if (info == nullptr) {
186       return std::unique_ptr<ChannelInterface>(nullptr);
187     }
188     // The intercepted channel starts from the interceptor just after the
189     // current interceptor
190     return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
191         info->channel(), current_interceptor_index_ + 1));
192   }
193 
FailHijackedRecvMessage()194   void FailHijackedRecvMessage() override {
195     GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
196         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
197     *hijacked_recv_message_failed_ = true;
198   }
199 
200   // Clears all state
ClearState()201   void ClearState() {
202     reverse_ = false;
203     ran_hijacking_interceptor_ = false;
204     ClearHookPoints();
205   }
206 
207   // Prepares for Post_recv operations
SetReverse()208   void SetReverse() {
209     reverse_ = true;
210     ran_hijacking_interceptor_ = false;
211     ClearHookPoints();
212   }
213 
214   // This needs to be set before interceptors are run
SetCall(Call * call)215   void SetCall(Call* call) { call_ = call; }
216 
217   // This needs to be set before interceptors are run using RunInterceptors().
218   // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
SetCallOpSetInterface(CallOpSetInterface * ops)219   void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
220 
221   // SetCall should have been called before this.
222   // Returns true if the interceptors list is empty
InterceptorsListEmpty()223   bool InterceptorsListEmpty() {
224     auto* client_rpc_info = call_->client_rpc_info();
225     if (client_rpc_info != nullptr) {
226       if (client_rpc_info->interceptors_.empty()) {
227         return true;
228       } else {
229         return false;
230       }
231     }
232 
233     auto* server_rpc_info = call_->server_rpc_info();
234     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
235       return true;
236     }
237     return false;
238   }
239 
240   // This should be used only by subclasses of CallOpSetInterface. SetCall and
241   // SetCallOpSetInterface should have been called before this. After all the
242   // interceptors are done running, either ContinueFillOpsAfterInterception or
243   // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
244   // them is invoked if there were no interceptors registered.
RunInterceptors()245   bool RunInterceptors() {
246     GPR_CODEGEN_ASSERT(ops_);
247     auto* client_rpc_info = call_->client_rpc_info();
248     if (client_rpc_info != nullptr) {
249       if (client_rpc_info->interceptors_.empty()) {
250         return true;
251       } else {
252         RunClientInterceptors();
253         return false;
254       }
255     }
256 
257     auto* server_rpc_info = call_->server_rpc_info();
258     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
259       return true;
260     }
261     RunServerInterceptors();
262     return false;
263   }
264 
265   // Returns true if no interceptors are run. Returns false otherwise if there
266   // are interceptors registered. After the interceptors are done running \a f
267   // will be invoked. This is to be used only by BaseAsyncRequest and
268   // SyncRequest.
RunInterceptors(std::function<void (void)> f)269   bool RunInterceptors(std::function<void(void)> f) {
270     // This is used only by the server for initial call request
271     GPR_CODEGEN_ASSERT(reverse_ == true);
272     GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
273     auto* server_rpc_info = call_->server_rpc_info();
274     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
275       return true;
276     }
277     callback_ = std::move(f);
278     RunServerInterceptors();
279     return false;
280   }
281 
282  private:
RunClientInterceptors()283   void RunClientInterceptors() {
284     auto* rpc_info = call_->client_rpc_info();
285     if (!reverse_) {
286       current_interceptor_index_ = 0;
287     } else {
288       if (rpc_info->hijacked_) {
289         current_interceptor_index_ = rpc_info->hijacked_interceptor_;
290       } else {
291         current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
292       }
293     }
294     rpc_info->RunInterceptor(this, current_interceptor_index_);
295   }
296 
RunServerInterceptors()297   void RunServerInterceptors() {
298     auto* rpc_info = call_->server_rpc_info();
299     if (!reverse_) {
300       current_interceptor_index_ = 0;
301     } else {
302       current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
303     }
304     rpc_info->RunInterceptor(this, current_interceptor_index_);
305   }
306 
ProceedClient()307   void ProceedClient() {
308     auto* rpc_info = call_->client_rpc_info();
309     if (rpc_info->hijacked_ && !reverse_ &&
310         current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
311         !ran_hijacking_interceptor_) {
312       // We now need to provide hijacked recv ops to this interceptor
313       ClearHookPoints();
314       ops_->SetHijackingState();
315       ran_hijacking_interceptor_ = true;
316       rpc_info->RunInterceptor(this, current_interceptor_index_);
317       return;
318     }
319     if (!reverse_) {
320       current_interceptor_index_++;
321       // We are going down the stack of interceptors
322       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
323         if (rpc_info->hijacked_ &&
324             current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
325           // This is a hijacked RPC and we are done with hijacking
326           ops_->ContinueFillOpsAfterInterception();
327         } else {
328           rpc_info->RunInterceptor(this, current_interceptor_index_);
329         }
330       } else {
331         // we are done running all the interceptors without any hijacking
332         ops_->ContinueFillOpsAfterInterception();
333       }
334     } else {
335       // We are going up the stack of interceptors
336       if (current_interceptor_index_ > 0) {
337         // Continue running interceptors
338         current_interceptor_index_--;
339         rpc_info->RunInterceptor(this, current_interceptor_index_);
340       } else {
341         // we are done running all the interceptors without any hijacking
342         ops_->ContinueFinalizeResultAfterInterception();
343       }
344     }
345   }
346 
ProceedServer()347   void ProceedServer() {
348     auto* rpc_info = call_->server_rpc_info();
349     if (!reverse_) {
350       current_interceptor_index_++;
351       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
352         return rpc_info->RunInterceptor(this, current_interceptor_index_);
353       } else if (ops_) {
354         return ops_->ContinueFillOpsAfterInterception();
355       }
356     } else {
357       // We are going up the stack of interceptors
358       if (current_interceptor_index_ > 0) {
359         // Continue running interceptors
360         current_interceptor_index_--;
361         return rpc_info->RunInterceptor(this, current_interceptor_index_);
362       } else if (ops_) {
363         return ops_->ContinueFinalizeResultAfterInterception();
364       }
365     }
366     GPR_CODEGEN_ASSERT(callback_);
367     callback_();
368   }
369 
ClearHookPoints()370   void ClearHookPoints() {
371     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
372          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
373          i = static_cast<experimental::InterceptionHookPoints>(
374              static_cast<size_t>(i) + 1)) {
375       hooks_[static_cast<size_t>(i)] = false;
376     }
377   }
378 
379   std::array<bool,
380              static_cast<size_t>(
381                  experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
382       hooks_;
383 
384   size_t current_interceptor_index_ = 0;  // Current iterator
385   bool reverse_ = false;
386   bool ran_hijacking_interceptor_ = false;
387   Call* call_ = nullptr;  // The Call object is present along with CallOpSet
388                           // object/callback
389   CallOpSetInterface* ops_ = nullptr;
390   std::function<void(void)> callback_;
391 
392   ByteBuffer* send_message_ = nullptr;
393   bool* fail_send_message_ = nullptr;
394   const void** orig_send_message_ = nullptr;
395   std::function<Status(const void*)> serializer_;
396 
397   std::multimap<std::string, std::string>* send_initial_metadata_;
398 
399   grpc_status_code* code_ = nullptr;
400   std::string* error_details_ = nullptr;
401   std::string* error_message_ = nullptr;
402 
403   std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
404 
405   void* recv_message_ = nullptr;
406   bool* hijacked_recv_message_failed_ = nullptr;
407 
408   MetadataMap* recv_initial_metadata_ = nullptr;
409 
410   Status* recv_status_ = nullptr;
411 
412   MetadataMap* recv_trailing_metadata_ = nullptr;
413 };
414 
415 // A special implementation of InterceptorBatchMethods to send a Cancel
416 // notification down the interceptor stack
417 class CancelInterceptorBatchMethods
418     : public experimental::InterceptorBatchMethods {
419  public:
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)420   bool QueryInterceptionHookPoint(
421       experimental::InterceptionHookPoints type) override {
422     if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) {
423       return true;
424     } else {
425       return false;
426     }
427   }
428 
Proceed()429   void Proceed() override {
430     // This is a no-op. For actual continuation of the RPC simply needs to
431     // return from the Intercept method
432   }
433 
Hijack()434   void Hijack() override {
435     // Only the client can hijack when sending down initial metadata
436     GPR_CODEGEN_ASSERT(false &&
437                        "It is illegal to call Hijack on a method which has a "
438                        "Cancel notification");
439   }
440 
GetSerializedSendMessage()441   ByteBuffer* GetSerializedSendMessage() override {
442     GPR_CODEGEN_ASSERT(false &&
443                        "It is illegal to call GetSendMessage on a method which "
444                        "has a Cancel notification");
445     return nullptr;
446   }
447 
GetSendMessageStatus()448   bool GetSendMessageStatus() override {
449     GPR_CODEGEN_ASSERT(
450         false &&
451         "It is illegal to call GetSendMessageStatus on a method which "
452         "has a Cancel notification");
453     return false;
454   }
455 
GetSendMessage()456   const void* GetSendMessage() override {
457     GPR_CODEGEN_ASSERT(
458         false &&
459         "It is illegal to call GetOriginalSendMessage on a method which "
460         "has a Cancel notification");
461     return nullptr;
462   }
463 
ModifySendMessage(const void *)464   void ModifySendMessage(const void* /*message*/) override {
465     GPR_CODEGEN_ASSERT(
466         false &&
467         "It is illegal to call ModifySendMessage on a method which "
468         "has a Cancel notification");
469   }
470 
GetSendInitialMetadata()471   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
472     GPR_CODEGEN_ASSERT(false &&
473                        "It is illegal to call GetSendInitialMetadata on a "
474                        "method which has a Cancel notification");
475     return nullptr;
476   }
477 
GetSendStatus()478   Status GetSendStatus() override {
479     GPR_CODEGEN_ASSERT(false &&
480                        "It is illegal to call GetSendStatus on a method which "
481                        "has a Cancel notification");
482     return Status();
483   }
484 
ModifySendStatus(const Status &)485   void ModifySendStatus(const Status& /*status*/) override {
486     GPR_CODEGEN_ASSERT(false &&
487                        "It is illegal to call ModifySendStatus on a method "
488                        "which has a Cancel notification");
489   }
490 
GetSendTrailingMetadata()491   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
492     GPR_CODEGEN_ASSERT(false &&
493                        "It is illegal to call GetSendTrailingMetadata on a "
494                        "method which has a Cancel notification");
495     return nullptr;
496   }
497 
GetRecvMessage()498   void* GetRecvMessage() override {
499     GPR_CODEGEN_ASSERT(false &&
500                        "It is illegal to call GetRecvMessage on a method which "
501                        "has a Cancel notification");
502     return nullptr;
503   }
504 
GetRecvInitialMetadata()505   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
506       override {
507     GPR_CODEGEN_ASSERT(false &&
508                        "It is illegal to call GetRecvInitialMetadata on a "
509                        "method which has a Cancel notification");
510     return nullptr;
511   }
512 
GetRecvStatus()513   Status* GetRecvStatus() override {
514     GPR_CODEGEN_ASSERT(false &&
515                        "It is illegal to call GetRecvStatus on a method which "
516                        "has a Cancel notification");
517     return nullptr;
518   }
519 
GetRecvTrailingMetadata()520   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
521       override {
522     GPR_CODEGEN_ASSERT(false &&
523                        "It is illegal to call GetRecvTrailingMetadata on a "
524                        "method which has a Cancel notification");
525     return nullptr;
526   }
527 
GetInterceptedChannel()528   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
529     GPR_CODEGEN_ASSERT(false &&
530                        "It is illegal to call GetInterceptedChannel on a "
531                        "method which has a Cancel notification");
532     return std::unique_ptr<ChannelInterface>(nullptr);
533   }
534 
FailHijackedRecvMessage()535   void FailHijackedRecvMessage() override {
536     GPR_CODEGEN_ASSERT(false &&
537                        "It is illegal to call FailHijackedRecvMessage on a "
538                        "method which has a Cancel notification");
539   }
540 
FailHijackedSendMessage()541   void FailHijackedSendMessage() override {
542     GPR_CODEGEN_ASSERT(false &&
543                        "It is illegal to call FailHijackedSendMessage on a "
544                        "method which has a Cancel notification");
545   }
546 };
547 }  // namespace internal
548 }  // namespace grpc
549 
550 #endif  // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
551