1 // 2 // 3 // Copyright 2018 gRPC authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // 17 // 18 19 #ifndef GRPCPP_IMPL_INTERCEPTOR_COMMON_H 20 #define GRPCPP_IMPL_INTERCEPTOR_COMMON_H 21 22 #include <grpc/impl/grpc_types.h> 23 #include <grpcpp/impl/call.h> 24 #include <grpcpp/impl/call_op_set_interface.h> 25 #include <grpcpp/impl/intercepted_channel.h> 26 #include <grpcpp/support/client_interceptor.h> 27 #include <grpcpp/support/server_interceptor.h> 28 29 #include <array> 30 #include <functional> 31 32 #include "absl/log/absl_check.h" 33 34 namespace grpc { 35 namespace internal { 36 37 class InterceptorBatchMethodsImpl 38 : public experimental::InterceptorBatchMethods { 39 public: InterceptorBatchMethodsImpl()40 InterceptorBatchMethodsImpl() { 41 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 42 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 43 i = static_cast<experimental::InterceptionHookPoints>( 44 static_cast<size_t>(i) + 1)) { 45 hooks_[static_cast<size_t>(i)] = false; 46 } 47 } 48 ~InterceptorBatchMethodsImpl()49 ~InterceptorBatchMethodsImpl() override {} 50 QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)51 bool QueryInterceptionHookPoint( 52 experimental::InterceptionHookPoints type) override { 53 return hooks_[static_cast<size_t>(type)]; 54 } 55 Proceed()56 void Proceed() override { 57 if (call_->client_rpc_info() != nullptr) { 58 return ProceedClient(); 59 } 60 ABSL_CHECK_NE(call_->server_rpc_info(), nullptr); 61 ProceedServer(); 62 } 63 Hijack()64 void Hijack() override { 65 // Only the client can hijack when sending down initial metadata 66 ABSL_CHECK(!reverse_ && ops_ != nullptr && 67 call_->client_rpc_info() != nullptr); 68 // It is illegal to call Hijack twice 69 ABSL_CHECK(!ran_hijacking_interceptor_); 70 auto* rpc_info = call_->client_rpc_info(); 71 rpc_info->hijacked_ = true; 72 rpc_info->hijacked_interceptor_ = current_interceptor_index_; 73 ClearHookPoints(); 74 ops_->SetHijackingState(); 75 ran_hijacking_interceptor_ = true; 76 rpc_info->RunInterceptor(this, current_interceptor_index_); 77 } 78 AddInterceptionHookPoint(experimental::InterceptionHookPoints type)79 void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { 80 hooks_[static_cast<size_t>(type)] = true; 81 } 82 GetSerializedSendMessage()83 ByteBuffer* GetSerializedSendMessage() override { 84 ABSL_CHECK_NE(orig_send_message_, nullptr); 85 if (*orig_send_message_ != nullptr) { 86 ABSL_CHECK(serializer_(*orig_send_message_).ok()); 87 *orig_send_message_ = nullptr; 88 } 89 return send_message_; 90 } 91 GetSendMessage()92 const void* GetSendMessage() override { 93 ABSL_CHECK_NE(orig_send_message_, nullptr); 94 return *orig_send_message_; 95 } 96 ModifySendMessage(const void * message)97 void ModifySendMessage(const void* message) override { 98 ABSL_CHECK_NE(orig_send_message_, nullptr); 99 *orig_send_message_ = message; 100 } 101 GetSendMessageStatus()102 bool GetSendMessageStatus() override { return !*fail_send_message_; } 103 GetSendInitialMetadata()104 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 105 return send_initial_metadata_; 106 } 107 GetSendStatus()108 Status GetSendStatus() override { 109 return Status(static_cast<StatusCode>(*code_), *error_message_, 110 *error_details_); 111 } 112 ModifySendStatus(const Status & status)113 void ModifySendStatus(const Status& status) override { 114 *code_ = static_cast<grpc_status_code>(status.error_code()); 115 *error_details_ = status.error_details(); 116 *error_message_ = status.error_message(); 117 } 118 GetSendTrailingMetadata()119 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 120 return send_trailing_metadata_; 121 } 122 GetRecvMessage()123 void* GetRecvMessage() override { return recv_message_; } 124 GetRecvInitialMetadata()125 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 126 override { 127 return recv_initial_metadata_->map(); 128 } 129 GetRecvStatus()130 Status* GetRecvStatus() override { return recv_status_; } 131 FailHijackedSendMessage()132 void FailHijackedSendMessage() override { 133 ABSL_CHECK(hooks_[static_cast<size_t>( 134 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); 135 *fail_send_message_ = true; 136 } 137 GetRecvTrailingMetadata()138 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 139 override { 140 return recv_trailing_metadata_->map(); 141 } 142 SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)143 void SetSendMessage(ByteBuffer* buf, const void** msg, 144 bool* fail_send_message, 145 std::function<Status(const void*)> serializer) { 146 send_message_ = buf; 147 orig_send_message_ = msg; 148 fail_send_message_ = fail_send_message; 149 serializer_ = serializer; 150 } 151 SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)152 void SetSendInitialMetadata( 153 std::multimap<std::string, std::string>* metadata) { 154 send_initial_metadata_ = metadata; 155 } 156 SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)157 void SetSendStatus(grpc_status_code* code, std::string* error_details, 158 std::string* error_message) { 159 code_ = code; 160 error_details_ = error_details; 161 error_message_ = error_message; 162 } 163 SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)164 void SetSendTrailingMetadata( 165 std::multimap<std::string, std::string>* metadata) { 166 send_trailing_metadata_ = metadata; 167 } 168 SetRecvMessage(void * message,bool * hijacked_recv_message_failed)169 void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) { 170 recv_message_ = message; 171 hijacked_recv_message_failed_ = hijacked_recv_message_failed; 172 } 173 SetRecvInitialMetadata(MetadataMap * map)174 void SetRecvInitialMetadata(MetadataMap* map) { 175 recv_initial_metadata_ = map; 176 } 177 SetRecvStatus(Status * status)178 void SetRecvStatus(Status* status) { recv_status_ = status; } 179 SetRecvTrailingMetadata(MetadataMap * map)180 void SetRecvTrailingMetadata(MetadataMap* map) { 181 recv_trailing_metadata_ = map; 182 } 183 GetInterceptedChannel()184 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 185 auto* info = call_->client_rpc_info(); 186 if (info == nullptr) { 187 return std::unique_ptr<ChannelInterface>(nullptr); 188 } 189 // The intercepted channel starts from the interceptor just after the 190 // current interceptor 191 return std::unique_ptr<ChannelInterface>(new InterceptedChannel( 192 info->channel(), current_interceptor_index_ + 1)); 193 } 194 FailHijackedRecvMessage()195 void FailHijackedRecvMessage() override { 196 ABSL_CHECK(hooks_[static_cast<size_t>( 197 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); 198 *hijacked_recv_message_failed_ = true; 199 } 200 201 // Clears all state ClearState()202 void ClearState() { 203 reverse_ = false; 204 ran_hijacking_interceptor_ = false; 205 ClearHookPoints(); 206 } 207 208 // Prepares for Post_recv operations SetReverse()209 void SetReverse() { 210 reverse_ = true; 211 ran_hijacking_interceptor_ = false; 212 ClearHookPoints(); 213 } 214 215 // This needs to be set before interceptors are run SetCall(Call * call)216 void SetCall(Call* call) { call_ = call; } 217 218 // This needs to be set before interceptors are run using RunInterceptors(). 219 // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. SetCallOpSetInterface(CallOpSetInterface * ops)220 void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } 221 222 // SetCall should have been called before this. 223 // Returns true if the interceptors list is empty InterceptorsListEmpty()224 bool InterceptorsListEmpty() { 225 auto* client_rpc_info = call_->client_rpc_info(); 226 if (client_rpc_info != nullptr) { 227 return client_rpc_info->interceptors_.empty(); 228 } 229 230 auto* server_rpc_info = call_->server_rpc_info(); 231 return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty(); 232 } 233 234 // This should be used only by subclasses of CallOpSetInterface. SetCall and 235 // SetCallOpSetInterface should have been called before this. After all the 236 // interceptors are done running, either ContinueFillOpsAfterInterception or 237 // ContinueFinalizeOpsAfterInterception will be called. Note that neither of 238 // them is invoked if there were no interceptors registered. RunInterceptors()239 bool RunInterceptors() { 240 ABSL_CHECK(ops_); 241 auto* client_rpc_info = call_->client_rpc_info(); 242 if (client_rpc_info != nullptr) { 243 if (client_rpc_info->interceptors_.empty()) { 244 return true; 245 } else { 246 RunClientInterceptors(); 247 return false; 248 } 249 } 250 251 auto* server_rpc_info = call_->server_rpc_info(); 252 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 253 return true; 254 } 255 RunServerInterceptors(); 256 return false; 257 } 258 259 // Returns true if no interceptors are run. Returns false otherwise if there 260 // are interceptors registered. After the interceptors are done running \a f 261 // will be invoked. This is to be used only by BaseAsyncRequest and 262 // SyncRequest. RunInterceptors(std::function<void (void)> f)263 bool RunInterceptors(std::function<void(void)> f) { 264 // This is used only by the server for initial call request 265 ABSL_CHECK_EQ(reverse_, true); 266 ABSL_CHECK_EQ(call_->client_rpc_info(), nullptr); 267 auto* server_rpc_info = call_->server_rpc_info(); 268 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 269 return true; 270 } 271 callback_ = std::move(f); 272 RunServerInterceptors(); 273 return false; 274 } 275 276 private: RunClientInterceptors()277 void RunClientInterceptors() { 278 auto* rpc_info = call_->client_rpc_info(); 279 if (!reverse_) { 280 current_interceptor_index_ = 0; 281 } else { 282 if (rpc_info->hijacked_) { 283 current_interceptor_index_ = rpc_info->hijacked_interceptor_; 284 } else { 285 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 286 } 287 } 288 rpc_info->RunInterceptor(this, current_interceptor_index_); 289 } 290 RunServerInterceptors()291 void RunServerInterceptors() { 292 auto* rpc_info = call_->server_rpc_info(); 293 if (!reverse_) { 294 current_interceptor_index_ = 0; 295 } else { 296 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 297 } 298 rpc_info->RunInterceptor(this, current_interceptor_index_); 299 } 300 ProceedClient()301 void ProceedClient() { 302 auto* rpc_info = call_->client_rpc_info(); 303 if (rpc_info->hijacked_ && !reverse_ && 304 current_interceptor_index_ == rpc_info->hijacked_interceptor_ && 305 !ran_hijacking_interceptor_) { 306 // We now need to provide hijacked recv ops to this interceptor 307 ClearHookPoints(); 308 ops_->SetHijackingState(); 309 ran_hijacking_interceptor_ = true; 310 rpc_info->RunInterceptor(this, current_interceptor_index_); 311 return; 312 } 313 if (!reverse_) { 314 current_interceptor_index_++; 315 // We are going down the stack of interceptors 316 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 317 if (rpc_info->hijacked_ && 318 current_interceptor_index_ > rpc_info->hijacked_interceptor_) { 319 // This is a hijacked RPC and we are done with hijacking 320 ops_->ContinueFillOpsAfterInterception(); 321 } else { 322 rpc_info->RunInterceptor(this, current_interceptor_index_); 323 } 324 } else { 325 // we are done running all the interceptors without any hijacking 326 ops_->ContinueFillOpsAfterInterception(); 327 } 328 } else { 329 // We are going up the stack of interceptors 330 if (current_interceptor_index_ > 0) { 331 // Continue running interceptors 332 current_interceptor_index_--; 333 rpc_info->RunInterceptor(this, current_interceptor_index_); 334 } else { 335 // we are done running all the interceptors without any hijacking 336 ops_->ContinueFinalizeResultAfterInterception(); 337 } 338 } 339 } 340 ProceedServer()341 void ProceedServer() { 342 auto* rpc_info = call_->server_rpc_info(); 343 if (!reverse_) { 344 current_interceptor_index_++; 345 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 346 return rpc_info->RunInterceptor(this, current_interceptor_index_); 347 } else if (ops_) { 348 return ops_->ContinueFillOpsAfterInterception(); 349 } 350 } else { 351 // We are going up the stack of interceptors 352 if (current_interceptor_index_ > 0) { 353 // Continue running interceptors 354 current_interceptor_index_--; 355 return rpc_info->RunInterceptor(this, current_interceptor_index_); 356 } else if (ops_) { 357 return ops_->ContinueFinalizeResultAfterInterception(); 358 } 359 } 360 ABSL_CHECK(callback_); 361 callback_(); 362 } 363 ClearHookPoints()364 void ClearHookPoints() { 365 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 366 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 367 i = static_cast<experimental::InterceptionHookPoints>( 368 static_cast<size_t>(i) + 1)) { 369 hooks_[static_cast<size_t>(i)] = false; 370 } 371 } 372 373 std::array<bool, 374 static_cast<size_t>( 375 experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> 376 hooks_; 377 378 size_t current_interceptor_index_ = 0; // Current iterator 379 bool reverse_ = false; 380 bool ran_hijacking_interceptor_ = false; 381 Call* call_ = nullptr; // The Call object is present along with CallOpSet 382 // object/callback 383 CallOpSetInterface* ops_ = nullptr; 384 std::function<void(void)> callback_; 385 386 ByteBuffer* send_message_ = nullptr; 387 bool* fail_send_message_ = nullptr; 388 const void** orig_send_message_ = nullptr; 389 std::function<Status(const void*)> serializer_; 390 391 std::multimap<std::string, std::string>* send_initial_metadata_; 392 393 grpc_status_code* code_ = nullptr; 394 std::string* error_details_ = nullptr; 395 std::string* error_message_ = nullptr; 396 397 std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr; 398 399 void* recv_message_ = nullptr; 400 bool* hijacked_recv_message_failed_ = nullptr; 401 402 MetadataMap* recv_initial_metadata_ = nullptr; 403 404 Status* recv_status_ = nullptr; 405 406 MetadataMap* recv_trailing_metadata_ = nullptr; 407 }; 408 409 // A special implementation of InterceptorBatchMethods to send a Cancel 410 // notification down the interceptor stack 411 class CancelInterceptorBatchMethods 412 : public experimental::InterceptorBatchMethods { 413 public: QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)414 bool QueryInterceptionHookPoint( 415 experimental::InterceptionHookPoints type) override { 416 return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL; 417 } 418 Proceed()419 void Proceed() override { 420 // This is a no-op. For actual continuation of the RPC simply needs to 421 // return from the Intercept method 422 } 423 Hijack()424 void Hijack() override { 425 // Only the client can hijack when sending down initial metadata 426 ABSL_CHECK(false) << "It is illegal to call Hijack on a method which has a " 427 "Cancel notification"; 428 } 429 GetSerializedSendMessage()430 ByteBuffer* GetSerializedSendMessage() override { 431 ABSL_CHECK(false) 432 << "It is illegal to call GetSendMessage on a method which " 433 "has a Cancel notification"; 434 return nullptr; 435 } 436 GetSendMessageStatus()437 bool GetSendMessageStatus() override { 438 ABSL_CHECK(false) 439 << "It is illegal to call GetSendMessageStatus on a method which " 440 "has a Cancel notification"; 441 return false; 442 } 443 GetSendMessage()444 const void* GetSendMessage() override { 445 ABSL_CHECK(false) 446 << "It is illegal to call GetOriginalSendMessage on a method which " 447 "has a Cancel notification"; 448 return nullptr; 449 } 450 ModifySendMessage(const void *)451 void ModifySendMessage(const void* /*message*/) override { 452 ABSL_CHECK(false) 453 << "It is illegal to call ModifySendMessage on a method which " 454 "has a Cancel notification"; 455 } 456 GetSendInitialMetadata()457 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 458 ABSL_CHECK(false) << "It is illegal to call GetSendInitialMetadata on a " 459 "method which has a Cancel notification"; 460 return nullptr; 461 } 462 GetSendStatus()463 Status GetSendStatus() override { 464 ABSL_CHECK(false) 465 << "It is illegal to call GetSendStatus on a method which " 466 "has a Cancel notification"; 467 return Status(); 468 } 469 ModifySendStatus(const Status &)470 void ModifySendStatus(const Status& /*status*/) override { 471 ABSL_CHECK(false) << "It is illegal to call ModifySendStatus on a method " 472 "which has a Cancel notification"; 473 } 474 GetSendTrailingMetadata()475 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 476 ABSL_CHECK(false) << "It is illegal to call GetSendTrailingMetadata on a " 477 "method which has a Cancel notification"; 478 return nullptr; 479 } 480 GetRecvMessage()481 void* GetRecvMessage() override { 482 ABSL_CHECK(false) 483 << "It is illegal to call GetRecvMessage on a method which " 484 "has a Cancel notification"; 485 return nullptr; 486 } 487 GetRecvInitialMetadata()488 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 489 override { 490 ABSL_CHECK(false) << "It is illegal to call GetRecvInitialMetadata on a " 491 "method which has a Cancel notification"; 492 return nullptr; 493 } 494 GetRecvStatus()495 Status* GetRecvStatus() override { 496 ABSL_CHECK(false) 497 << "It is illegal to call GetRecvStatus on a method which " 498 "has a Cancel notification"; 499 return nullptr; 500 } 501 GetRecvTrailingMetadata()502 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 503 override { 504 ABSL_CHECK(false) << "It is illegal to call GetRecvTrailingMetadata on a " 505 "method which has a Cancel notification"; 506 return nullptr; 507 } 508 GetInterceptedChannel()509 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 510 ABSL_CHECK(false) << "It is illegal to call GetInterceptedChannel on a " 511 "method which has a Cancel notification"; 512 return std::unique_ptr<ChannelInterface>(nullptr); 513 } 514 FailHijackedRecvMessage()515 void FailHijackedRecvMessage() override { 516 ABSL_CHECK(false) << "It is illegal to call FailHijackedRecvMessage on a " 517 "method which has a Cancel notification"; 518 } 519 FailHijackedSendMessage()520 void FailHijackedSendMessage() override { 521 ABSL_CHECK(false) << "It is illegal to call FailHijackedSendMessage on a " 522 "method which has a Cancel notification"; 523 } 524 }; 525 } // namespace internal 526 } // namespace grpc 527 528 #endif // GRPCPP_IMPL_INTERCEPTOR_COMMON_H 529