• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #ifndef GRPCPP_IMPL_INTERCEPTOR_COMMON_H
20 #define GRPCPP_IMPL_INTERCEPTOR_COMMON_H
21 
22 #include <grpc/impl/grpc_types.h>
23 #include <grpcpp/impl/call.h>
24 #include <grpcpp/impl/call_op_set_interface.h>
25 #include <grpcpp/impl/intercepted_channel.h>
26 #include <grpcpp/support/client_interceptor.h>
27 #include <grpcpp/support/server_interceptor.h>
28 
29 #include <array>
30 #include <functional>
31 
32 #include "absl/log/absl_check.h"
33 
34 namespace grpc {
35 namespace internal {
36 
37 class InterceptorBatchMethodsImpl
38     : public experimental::InterceptorBatchMethods {
39  public:
InterceptorBatchMethodsImpl()40   InterceptorBatchMethodsImpl() {
41     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
42          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
43          i = static_cast<experimental::InterceptionHookPoints>(
44              static_cast<size_t>(i) + 1)) {
45       hooks_[static_cast<size_t>(i)] = false;
46     }
47   }
48 
~InterceptorBatchMethodsImpl()49   ~InterceptorBatchMethodsImpl() override {}
50 
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)51   bool QueryInterceptionHookPoint(
52       experimental::InterceptionHookPoints type) override {
53     return hooks_[static_cast<size_t>(type)];
54   }
55 
Proceed()56   void Proceed() override {
57     if (call_->client_rpc_info() != nullptr) {
58       return ProceedClient();
59     }
60     ABSL_CHECK_NE(call_->server_rpc_info(), nullptr);
61     ProceedServer();
62   }
63 
Hijack()64   void Hijack() override {
65     // Only the client can hijack when sending down initial metadata
66     ABSL_CHECK(!reverse_ && ops_ != nullptr &&
67                call_->client_rpc_info() != nullptr);
68     // It is illegal to call Hijack twice
69     ABSL_CHECK(!ran_hijacking_interceptor_);
70     auto* rpc_info = call_->client_rpc_info();
71     rpc_info->hijacked_ = true;
72     rpc_info->hijacked_interceptor_ = current_interceptor_index_;
73     ClearHookPoints();
74     ops_->SetHijackingState();
75     ran_hijacking_interceptor_ = true;
76     rpc_info->RunInterceptor(this, current_interceptor_index_);
77   }
78 
AddInterceptionHookPoint(experimental::InterceptionHookPoints type)79   void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
80     hooks_[static_cast<size_t>(type)] = true;
81   }
82 
GetSerializedSendMessage()83   ByteBuffer* GetSerializedSendMessage() override {
84     ABSL_CHECK_NE(orig_send_message_, nullptr);
85     if (*orig_send_message_ != nullptr) {
86       ABSL_CHECK(serializer_(*orig_send_message_).ok());
87       *orig_send_message_ = nullptr;
88     }
89     return send_message_;
90   }
91 
GetSendMessage()92   const void* GetSendMessage() override {
93     ABSL_CHECK_NE(orig_send_message_, nullptr);
94     return *orig_send_message_;
95   }
96 
ModifySendMessage(const void * message)97   void ModifySendMessage(const void* message) override {
98     ABSL_CHECK_NE(orig_send_message_, nullptr);
99     *orig_send_message_ = message;
100   }
101 
GetSendMessageStatus()102   bool GetSendMessageStatus() override { return !*fail_send_message_; }
103 
GetSendInitialMetadata()104   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
105     return send_initial_metadata_;
106   }
107 
GetSendStatus()108   Status GetSendStatus() override {
109     return Status(static_cast<StatusCode>(*code_), *error_message_,
110                   *error_details_);
111   }
112 
ModifySendStatus(const Status & status)113   void ModifySendStatus(const Status& status) override {
114     *code_ = static_cast<grpc_status_code>(status.error_code());
115     *error_details_ = status.error_details();
116     *error_message_ = status.error_message();
117   }
118 
GetSendTrailingMetadata()119   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
120     return send_trailing_metadata_;
121   }
122 
GetRecvMessage()123   void* GetRecvMessage() override { return recv_message_; }
124 
GetRecvInitialMetadata()125   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
126       override {
127     return recv_initial_metadata_->map();
128   }
129 
GetRecvStatus()130   Status* GetRecvStatus() override { return recv_status_; }
131 
FailHijackedSendMessage()132   void FailHijackedSendMessage() override {
133     ABSL_CHECK(hooks_[static_cast<size_t>(
134         experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
135     *fail_send_message_ = true;
136   }
137 
GetRecvTrailingMetadata()138   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
139       override {
140     return recv_trailing_metadata_->map();
141   }
142 
SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)143   void SetSendMessage(ByteBuffer* buf, const void** msg,
144                       bool* fail_send_message,
145                       std::function<Status(const void*)> serializer) {
146     send_message_ = buf;
147     orig_send_message_ = msg;
148     fail_send_message_ = fail_send_message;
149     serializer_ = serializer;
150   }
151 
SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)152   void SetSendInitialMetadata(
153       std::multimap<std::string, std::string>* metadata) {
154     send_initial_metadata_ = metadata;
155   }
156 
SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)157   void SetSendStatus(grpc_status_code* code, std::string* error_details,
158                      std::string* error_message) {
159     code_ = code;
160     error_details_ = error_details;
161     error_message_ = error_message;
162   }
163 
SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)164   void SetSendTrailingMetadata(
165       std::multimap<std::string, std::string>* metadata) {
166     send_trailing_metadata_ = metadata;
167   }
168 
SetRecvMessage(void * message,bool * hijacked_recv_message_failed)169   void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
170     recv_message_ = message;
171     hijacked_recv_message_failed_ = hijacked_recv_message_failed;
172   }
173 
SetRecvInitialMetadata(MetadataMap * map)174   void SetRecvInitialMetadata(MetadataMap* map) {
175     recv_initial_metadata_ = map;
176   }
177 
SetRecvStatus(Status * status)178   void SetRecvStatus(Status* status) { recv_status_ = status; }
179 
SetRecvTrailingMetadata(MetadataMap * map)180   void SetRecvTrailingMetadata(MetadataMap* map) {
181     recv_trailing_metadata_ = map;
182   }
183 
GetInterceptedChannel()184   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
185     auto* info = call_->client_rpc_info();
186     if (info == nullptr) {
187       return std::unique_ptr<ChannelInterface>(nullptr);
188     }
189     // The intercepted channel starts from the interceptor just after the
190     // current interceptor
191     return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
192         info->channel(), current_interceptor_index_ + 1));
193   }
194 
FailHijackedRecvMessage()195   void FailHijackedRecvMessage() override {
196     ABSL_CHECK(hooks_[static_cast<size_t>(
197         experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
198     *hijacked_recv_message_failed_ = true;
199   }
200 
201   // Clears all state
ClearState()202   void ClearState() {
203     reverse_ = false;
204     ran_hijacking_interceptor_ = false;
205     ClearHookPoints();
206   }
207 
208   // Prepares for Post_recv operations
SetReverse()209   void SetReverse() {
210     reverse_ = true;
211     ran_hijacking_interceptor_ = false;
212     ClearHookPoints();
213   }
214 
215   // This needs to be set before interceptors are run
SetCall(Call * call)216   void SetCall(Call* call) { call_ = call; }
217 
218   // This needs to be set before interceptors are run using RunInterceptors().
219   // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
SetCallOpSetInterface(CallOpSetInterface * ops)220   void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
221 
222   // SetCall should have been called before this.
223   // Returns true if the interceptors list is empty
InterceptorsListEmpty()224   bool InterceptorsListEmpty() {
225     auto* client_rpc_info = call_->client_rpc_info();
226     if (client_rpc_info != nullptr) {
227       return client_rpc_info->interceptors_.empty();
228     }
229 
230     auto* server_rpc_info = call_->server_rpc_info();
231     return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty();
232   }
233 
234   // This should be used only by subclasses of CallOpSetInterface. SetCall and
235   // SetCallOpSetInterface should have been called before this. After all the
236   // interceptors are done running, either ContinueFillOpsAfterInterception or
237   // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
238   // them is invoked if there were no interceptors registered.
RunInterceptors()239   bool RunInterceptors() {
240     ABSL_CHECK(ops_);
241     auto* client_rpc_info = call_->client_rpc_info();
242     if (client_rpc_info != nullptr) {
243       if (client_rpc_info->interceptors_.empty()) {
244         return true;
245       } else {
246         RunClientInterceptors();
247         return false;
248       }
249     }
250 
251     auto* server_rpc_info = call_->server_rpc_info();
252     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
253       return true;
254     }
255     RunServerInterceptors();
256     return false;
257   }
258 
259   // Returns true if no interceptors are run. Returns false otherwise if there
260   // are interceptors registered. After the interceptors are done running \a f
261   // will be invoked. This is to be used only by BaseAsyncRequest and
262   // SyncRequest.
RunInterceptors(std::function<void (void)> f)263   bool RunInterceptors(std::function<void(void)> f) {
264     // This is used only by the server for initial call request
265     ABSL_CHECK_EQ(reverse_, true);
266     ABSL_CHECK_EQ(call_->client_rpc_info(), nullptr);
267     auto* server_rpc_info = call_->server_rpc_info();
268     if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
269       return true;
270     }
271     callback_ = std::move(f);
272     RunServerInterceptors();
273     return false;
274   }
275 
276  private:
RunClientInterceptors()277   void RunClientInterceptors() {
278     auto* rpc_info = call_->client_rpc_info();
279     if (!reverse_) {
280       current_interceptor_index_ = 0;
281     } else {
282       if (rpc_info->hijacked_) {
283         current_interceptor_index_ = rpc_info->hijacked_interceptor_;
284       } else {
285         current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
286       }
287     }
288     rpc_info->RunInterceptor(this, current_interceptor_index_);
289   }
290 
RunServerInterceptors()291   void RunServerInterceptors() {
292     auto* rpc_info = call_->server_rpc_info();
293     if (!reverse_) {
294       current_interceptor_index_ = 0;
295     } else {
296       current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
297     }
298     rpc_info->RunInterceptor(this, current_interceptor_index_);
299   }
300 
ProceedClient()301   void ProceedClient() {
302     auto* rpc_info = call_->client_rpc_info();
303     if (rpc_info->hijacked_ && !reverse_ &&
304         current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
305         !ran_hijacking_interceptor_) {
306       // We now need to provide hijacked recv ops to this interceptor
307       ClearHookPoints();
308       ops_->SetHijackingState();
309       ran_hijacking_interceptor_ = true;
310       rpc_info->RunInterceptor(this, current_interceptor_index_);
311       return;
312     }
313     if (!reverse_) {
314       current_interceptor_index_++;
315       // We are going down the stack of interceptors
316       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
317         if (rpc_info->hijacked_ &&
318             current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
319           // This is a hijacked RPC and we are done with hijacking
320           ops_->ContinueFillOpsAfterInterception();
321         } else {
322           rpc_info->RunInterceptor(this, current_interceptor_index_);
323         }
324       } else {
325         // we are done running all the interceptors without any hijacking
326         ops_->ContinueFillOpsAfterInterception();
327       }
328     } else {
329       // We are going up the stack of interceptors
330       if (current_interceptor_index_ > 0) {
331         // Continue running interceptors
332         current_interceptor_index_--;
333         rpc_info->RunInterceptor(this, current_interceptor_index_);
334       } else {
335         // we are done running all the interceptors without any hijacking
336         ops_->ContinueFinalizeResultAfterInterception();
337       }
338     }
339   }
340 
ProceedServer()341   void ProceedServer() {
342     auto* rpc_info = call_->server_rpc_info();
343     if (!reverse_) {
344       current_interceptor_index_++;
345       if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
346         return rpc_info->RunInterceptor(this, current_interceptor_index_);
347       } else if (ops_) {
348         return ops_->ContinueFillOpsAfterInterception();
349       }
350     } else {
351       // We are going up the stack of interceptors
352       if (current_interceptor_index_ > 0) {
353         // Continue running interceptors
354         current_interceptor_index_--;
355         return rpc_info->RunInterceptor(this, current_interceptor_index_);
356       } else if (ops_) {
357         return ops_->ContinueFinalizeResultAfterInterception();
358       }
359     }
360     ABSL_CHECK(callback_);
361     callback_();
362   }
363 
ClearHookPoints()364   void ClearHookPoints() {
365     for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
366          i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
367          i = static_cast<experimental::InterceptionHookPoints>(
368              static_cast<size_t>(i) + 1)) {
369       hooks_[static_cast<size_t>(i)] = false;
370     }
371   }
372 
373   std::array<bool,
374              static_cast<size_t>(
375                  experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
376       hooks_;
377 
378   size_t current_interceptor_index_ = 0;  // Current iterator
379   bool reverse_ = false;
380   bool ran_hijacking_interceptor_ = false;
381   Call* call_ = nullptr;  // The Call object is present along with CallOpSet
382                           // object/callback
383   CallOpSetInterface* ops_ = nullptr;
384   std::function<void(void)> callback_;
385 
386   ByteBuffer* send_message_ = nullptr;
387   bool* fail_send_message_ = nullptr;
388   const void** orig_send_message_ = nullptr;
389   std::function<Status(const void*)> serializer_;
390 
391   std::multimap<std::string, std::string>* send_initial_metadata_;
392 
393   grpc_status_code* code_ = nullptr;
394   std::string* error_details_ = nullptr;
395   std::string* error_message_ = nullptr;
396 
397   std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
398 
399   void* recv_message_ = nullptr;
400   bool* hijacked_recv_message_failed_ = nullptr;
401 
402   MetadataMap* recv_initial_metadata_ = nullptr;
403 
404   Status* recv_status_ = nullptr;
405 
406   MetadataMap* recv_trailing_metadata_ = nullptr;
407 };
408 
409 // A special implementation of InterceptorBatchMethods to send a Cancel
410 // notification down the interceptor stack
411 class CancelInterceptorBatchMethods
412     : public experimental::InterceptorBatchMethods {
413  public:
QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)414   bool QueryInterceptionHookPoint(
415       experimental::InterceptionHookPoints type) override {
416     return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL;
417   }
418 
Proceed()419   void Proceed() override {
420     // This is a no-op. For actual continuation of the RPC simply needs to
421     // return from the Intercept method
422   }
423 
Hijack()424   void Hijack() override {
425     // Only the client can hijack when sending down initial metadata
426     ABSL_CHECK(false) << "It is illegal to call Hijack on a method which has a "
427                          "Cancel notification";
428   }
429 
GetSerializedSendMessage()430   ByteBuffer* GetSerializedSendMessage() override {
431     ABSL_CHECK(false)
432         << "It is illegal to call GetSendMessage on a method which "
433            "has a Cancel notification";
434     return nullptr;
435   }
436 
GetSendMessageStatus()437   bool GetSendMessageStatus() override {
438     ABSL_CHECK(false)
439         << "It is illegal to call GetSendMessageStatus on a method which "
440            "has a Cancel notification";
441     return false;
442   }
443 
GetSendMessage()444   const void* GetSendMessage() override {
445     ABSL_CHECK(false)
446         << "It is illegal to call GetOriginalSendMessage on a method which "
447            "has a Cancel notification";
448     return nullptr;
449   }
450 
ModifySendMessage(const void *)451   void ModifySendMessage(const void* /*message*/) override {
452     ABSL_CHECK(false)
453         << "It is illegal to call ModifySendMessage on a method which "
454            "has a Cancel notification";
455   }
456 
GetSendInitialMetadata()457   std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
458     ABSL_CHECK(false) << "It is illegal to call GetSendInitialMetadata on a "
459                          "method which has a Cancel notification";
460     return nullptr;
461   }
462 
GetSendStatus()463   Status GetSendStatus() override {
464     ABSL_CHECK(false)
465         << "It is illegal to call GetSendStatus on a method which "
466            "has a Cancel notification";
467     return Status();
468   }
469 
ModifySendStatus(const Status &)470   void ModifySendStatus(const Status& /*status*/) override {
471     ABSL_CHECK(false) << "It is illegal to call ModifySendStatus on a method "
472                          "which has a Cancel notification";
473   }
474 
GetSendTrailingMetadata()475   std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
476     ABSL_CHECK(false) << "It is illegal to call GetSendTrailingMetadata on a "
477                          "method which has a Cancel notification";
478     return nullptr;
479   }
480 
GetRecvMessage()481   void* GetRecvMessage() override {
482     ABSL_CHECK(false)
483         << "It is illegal to call GetRecvMessage on a method which "
484            "has a Cancel notification";
485     return nullptr;
486   }
487 
GetRecvInitialMetadata()488   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
489       override {
490     ABSL_CHECK(false) << "It is illegal to call GetRecvInitialMetadata on a "
491                          "method which has a Cancel notification";
492     return nullptr;
493   }
494 
GetRecvStatus()495   Status* GetRecvStatus() override {
496     ABSL_CHECK(false)
497         << "It is illegal to call GetRecvStatus on a method which "
498            "has a Cancel notification";
499     return nullptr;
500   }
501 
GetRecvTrailingMetadata()502   std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
503       override {
504     ABSL_CHECK(false) << "It is illegal to call GetRecvTrailingMetadata on a "
505                          "method which has a Cancel notification";
506     return nullptr;
507   }
508 
GetInterceptedChannel()509   std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
510     ABSL_CHECK(false) << "It is illegal to call GetInterceptedChannel on a "
511                          "method which has a Cancel notification";
512     return std::unique_ptr<ChannelInterface>(nullptr);
513   }
514 
FailHijackedRecvMessage()515   void FailHijackedRecvMessage() override {
516     ABSL_CHECK(false) << "It is illegal to call FailHijackedRecvMessage on a "
517                          "method which has a Cancel notification";
518   }
519 
FailHijackedSendMessage()520   void FailHijackedSendMessage() override {
521     ABSL_CHECK(false) << "It is illegal to call FailHijackedSendMessage on a "
522                          "method which has a Cancel notification";
523   }
524 };
525 }  // namespace internal
526 }  // namespace grpc
527 
528 #endif  // GRPCPP_IMPL_INTERCEPTOR_COMMON_H
529