1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 21 22 #include <array> 23 #include <functional> 24 25 #include <grpcpp/impl/codegen/call.h> 26 #include <grpcpp/impl/codegen/call_op_set_interface.h> 27 #include <grpcpp/impl/codegen/client_interceptor.h> 28 #include <grpcpp/impl/codegen/intercepted_channel.h> 29 #include <grpcpp/impl/codegen/server_interceptor.h> 30 31 #include <grpc/impl/codegen/grpc_types.h> 32 33 namespace grpc { 34 namespace internal { 35 36 class InterceptorBatchMethodsImpl 37 : public experimental::InterceptorBatchMethods { 38 public: InterceptorBatchMethodsImpl()39 InterceptorBatchMethodsImpl() { 40 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 41 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 42 i = static_cast<experimental::InterceptionHookPoints>( 43 static_cast<size_t>(i) + 1)) { 44 hooks_[static_cast<size_t>(i)] = false; 45 } 46 } 47 ~InterceptorBatchMethodsImpl()48 ~InterceptorBatchMethodsImpl() {} 49 QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50 bool QueryInterceptionHookPoint( 51 experimental::InterceptionHookPoints type) override { 52 return hooks_[static_cast<size_t>(type)]; 53 } 54 Proceed()55 void Proceed() override { 56 if (call_->client_rpc_info() != nullptr) { 57 return ProceedClient(); 58 } 59 GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); 60 ProceedServer(); 61 } 62 Hijack()63 void Hijack() override { 64 // Only the client can hijack when sending down initial metadata 65 GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && 66 call_->client_rpc_info() != nullptr); 67 // It is illegal to call Hijack twice 68 GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); 69 auto* rpc_info = call_->client_rpc_info(); 70 rpc_info->hijacked_ = true; 71 rpc_info->hijacked_interceptor_ = current_interceptor_index_; 72 ClearHookPoints(); 73 ops_->SetHijackingState(); 74 ran_hijacking_interceptor_ = true; 75 rpc_info->RunInterceptor(this, current_interceptor_index_); 76 } 77 AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78 void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { 79 hooks_[static_cast<size_t>(type)] = true; 80 } 81 GetSerializedSendMessage()82 ByteBuffer* GetSerializedSendMessage() override { 83 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 84 if (*orig_send_message_ != nullptr) { 85 GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok()); 86 *orig_send_message_ = nullptr; 87 } 88 return send_message_; 89 } 90 GetSendMessage()91 const void* GetSendMessage() override { 92 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 93 return *orig_send_message_; 94 } 95 ModifySendMessage(const void * message)96 void ModifySendMessage(const void* message) override { 97 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 98 *orig_send_message_ = message; 99 } 100 GetSendMessageStatus()101 bool GetSendMessageStatus() override { return !*fail_send_message_; } 102 GetSendInitialMetadata()103 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 104 return send_initial_metadata_; 105 } 106 GetSendStatus()107 Status GetSendStatus() override { 108 return Status(static_cast<StatusCode>(*code_), *error_message_, 109 *error_details_); 110 } 111 ModifySendStatus(const Status & status)112 void ModifySendStatus(const Status& status) override { 113 *code_ = static_cast<grpc_status_code>(status.error_code()); 114 *error_details_ = status.error_details(); 115 *error_message_ = status.error_message(); 116 } 117 GetSendTrailingMetadata()118 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 119 return send_trailing_metadata_; 120 } 121 GetRecvMessage()122 void* GetRecvMessage() override { return recv_message_; } 123 GetRecvInitialMetadata()124 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 125 override { 126 return recv_initial_metadata_->map(); 127 } 128 GetRecvStatus()129 Status* GetRecvStatus() override { return recv_status_; } 130 FailHijackedSendMessage()131 void FailHijackedSendMessage() override { 132 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 133 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); 134 *fail_send_message_ = true; 135 } 136 GetRecvTrailingMetadata()137 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 138 override { 139 return recv_trailing_metadata_->map(); 140 } 141 SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142 void SetSendMessage(ByteBuffer* buf, const void** msg, 143 bool* fail_send_message, 144 std::function<Status(const void*)> serializer) { 145 send_message_ = buf; 146 orig_send_message_ = msg; 147 fail_send_message_ = fail_send_message; 148 serializer_ = serializer; 149 } 150 SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151 void SetSendInitialMetadata( 152 std::multimap<std::string, std::string>* metadata) { 153 send_initial_metadata_ = metadata; 154 } 155 SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156 void SetSendStatus(grpc_status_code* code, std::string* error_details, 157 std::string* error_message) { 158 code_ = code; 159 error_details_ = error_details; 160 error_message_ = error_message; 161 } 162 SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163 void SetSendTrailingMetadata( 164 std::multimap<std::string, std::string>* metadata) { 165 send_trailing_metadata_ = metadata; 166 } 167 SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168 void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) { 169 recv_message_ = message; 170 hijacked_recv_message_failed_ = hijacked_recv_message_failed; 171 } 172 SetRecvInitialMetadata(MetadataMap * map)173 void SetRecvInitialMetadata(MetadataMap* map) { 174 recv_initial_metadata_ = map; 175 } 176 SetRecvStatus(Status * status)177 void SetRecvStatus(Status* status) { recv_status_ = status; } 178 SetRecvTrailingMetadata(MetadataMap * map)179 void SetRecvTrailingMetadata(MetadataMap* map) { 180 recv_trailing_metadata_ = map; 181 } 182 GetInterceptedChannel()183 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 184 auto* info = call_->client_rpc_info(); 185 if (info == nullptr) { 186 return std::unique_ptr<ChannelInterface>(nullptr); 187 } 188 // The intercepted channel starts from the interceptor just after the 189 // current interceptor 190 return std::unique_ptr<ChannelInterface>(new InterceptedChannel( 191 info->channel(), current_interceptor_index_ + 1)); 192 } 193 FailHijackedRecvMessage()194 void FailHijackedRecvMessage() override { 195 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 196 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); 197 *hijacked_recv_message_failed_ = true; 198 } 199 200 // Clears all state ClearState()201 void ClearState() { 202 reverse_ = false; 203 ran_hijacking_interceptor_ = false; 204 ClearHookPoints(); 205 } 206 207 // Prepares for Post_recv operations SetReverse()208 void SetReverse() { 209 reverse_ = true; 210 ran_hijacking_interceptor_ = false; 211 ClearHookPoints(); 212 } 213 214 // This needs to be set before interceptors are run SetCall(Call * call)215 void SetCall(Call* call) { call_ = call; } 216 217 // This needs to be set before interceptors are run using RunInterceptors(). 218 // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. SetCallOpSetInterface(CallOpSetInterface * ops)219 void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } 220 221 // SetCall should have been called before this. 222 // Returns true if the interceptors list is empty InterceptorsListEmpty()223 bool InterceptorsListEmpty() { 224 auto* client_rpc_info = call_->client_rpc_info(); 225 if (client_rpc_info != nullptr) { 226 if (client_rpc_info->interceptors_.size() == 0) { 227 return true; 228 } else { 229 return false; 230 } 231 } 232 233 auto* server_rpc_info = call_->server_rpc_info(); 234 if (server_rpc_info == nullptr || 235 server_rpc_info->interceptors_.size() == 0) { 236 return true; 237 } 238 return false; 239 } 240 241 // This should be used only by subclasses of CallOpSetInterface. SetCall and 242 // SetCallOpSetInterface should have been called before this. After all the 243 // interceptors are done running, either ContinueFillOpsAfterInterception or 244 // ContinueFinalizeOpsAfterInterception will be called. Note that neither of 245 // them is invoked if there were no interceptors registered. RunInterceptors()246 bool RunInterceptors() { 247 GPR_CODEGEN_ASSERT(ops_); 248 auto* client_rpc_info = call_->client_rpc_info(); 249 if (client_rpc_info != nullptr) { 250 if (client_rpc_info->interceptors_.size() == 0) { 251 return true; 252 } else { 253 RunClientInterceptors(); 254 return false; 255 } 256 } 257 258 auto* server_rpc_info = call_->server_rpc_info(); 259 if (server_rpc_info == nullptr || 260 server_rpc_info->interceptors_.size() == 0) { 261 return true; 262 } 263 RunServerInterceptors(); 264 return false; 265 } 266 267 // Returns true if no interceptors are run. Returns false otherwise if there 268 // are interceptors registered. After the interceptors are done running \a f 269 // will be invoked. This is to be used only by BaseAsyncRequest and 270 // SyncRequest. RunInterceptors(std::function<void (void)> f)271 bool RunInterceptors(std::function<void(void)> f) { 272 // This is used only by the server for initial call request 273 GPR_CODEGEN_ASSERT(reverse_ == true); 274 GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); 275 auto* server_rpc_info = call_->server_rpc_info(); 276 if (server_rpc_info == nullptr || 277 server_rpc_info->interceptors_.size() == 0) { 278 return true; 279 } 280 callback_ = std::move(f); 281 RunServerInterceptors(); 282 return false; 283 } 284 285 private: RunClientInterceptors()286 void RunClientInterceptors() { 287 auto* rpc_info = call_->client_rpc_info(); 288 if (!reverse_) { 289 current_interceptor_index_ = 0; 290 } else { 291 if (rpc_info->hijacked_) { 292 current_interceptor_index_ = rpc_info->hijacked_interceptor_; 293 } else { 294 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 295 } 296 } 297 rpc_info->RunInterceptor(this, current_interceptor_index_); 298 } 299 RunServerInterceptors()300 void RunServerInterceptors() { 301 auto* rpc_info = call_->server_rpc_info(); 302 if (!reverse_) { 303 current_interceptor_index_ = 0; 304 } else { 305 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 306 } 307 rpc_info->RunInterceptor(this, current_interceptor_index_); 308 } 309 ProceedClient()310 void ProceedClient() { 311 auto* rpc_info = call_->client_rpc_info(); 312 if (rpc_info->hijacked_ && !reverse_ && 313 current_interceptor_index_ == rpc_info->hijacked_interceptor_ && 314 !ran_hijacking_interceptor_) { 315 // We now need to provide hijacked recv ops to this interceptor 316 ClearHookPoints(); 317 ops_->SetHijackingState(); 318 ran_hijacking_interceptor_ = true; 319 rpc_info->RunInterceptor(this, current_interceptor_index_); 320 return; 321 } 322 if (!reverse_) { 323 current_interceptor_index_++; 324 // We are going down the stack of interceptors 325 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 326 if (rpc_info->hijacked_ && 327 current_interceptor_index_ > rpc_info->hijacked_interceptor_) { 328 // This is a hijacked RPC and we are done with hijacking 329 ops_->ContinueFillOpsAfterInterception(); 330 } else { 331 rpc_info->RunInterceptor(this, current_interceptor_index_); 332 } 333 } else { 334 // we are done running all the interceptors without any hijacking 335 ops_->ContinueFillOpsAfterInterception(); 336 } 337 } else { 338 // We are going up the stack of interceptors 339 if (current_interceptor_index_ > 0) { 340 // Continue running interceptors 341 current_interceptor_index_--; 342 rpc_info->RunInterceptor(this, current_interceptor_index_); 343 } else { 344 // we are done running all the interceptors without any hijacking 345 ops_->ContinueFinalizeResultAfterInterception(); 346 } 347 } 348 } 349 ProceedServer()350 void ProceedServer() { 351 auto* rpc_info = call_->server_rpc_info(); 352 if (!reverse_) { 353 current_interceptor_index_++; 354 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 355 return rpc_info->RunInterceptor(this, current_interceptor_index_); 356 } else if (ops_) { 357 return ops_->ContinueFillOpsAfterInterception(); 358 } 359 } else { 360 // We are going up the stack of interceptors 361 if (current_interceptor_index_ > 0) { 362 // Continue running interceptors 363 current_interceptor_index_--; 364 return rpc_info->RunInterceptor(this, current_interceptor_index_); 365 } else if (ops_) { 366 return ops_->ContinueFinalizeResultAfterInterception(); 367 } 368 } 369 GPR_CODEGEN_ASSERT(callback_); 370 callback_(); 371 } 372 ClearHookPoints()373 void ClearHookPoints() { 374 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 375 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 376 i = static_cast<experimental::InterceptionHookPoints>( 377 static_cast<size_t>(i) + 1)) { 378 hooks_[static_cast<size_t>(i)] = false; 379 } 380 } 381 382 std::array<bool, 383 static_cast<size_t>( 384 experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> 385 hooks_; 386 387 size_t current_interceptor_index_ = 0; // Current iterator 388 bool reverse_ = false; 389 bool ran_hijacking_interceptor_ = false; 390 Call* call_ = nullptr; // The Call object is present along with CallOpSet 391 // object/callback 392 CallOpSetInterface* ops_ = nullptr; 393 std::function<void(void)> callback_; 394 395 ByteBuffer* send_message_ = nullptr; 396 bool* fail_send_message_ = nullptr; 397 const void** orig_send_message_ = nullptr; 398 std::function<Status(const void*)> serializer_; 399 400 std::multimap<std::string, std::string>* send_initial_metadata_; 401 402 grpc_status_code* code_ = nullptr; 403 std::string* error_details_ = nullptr; 404 std::string* error_message_ = nullptr; 405 406 std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr; 407 408 void* recv_message_ = nullptr; 409 bool* hijacked_recv_message_failed_ = nullptr; 410 411 MetadataMap* recv_initial_metadata_ = nullptr; 412 413 Status* recv_status_ = nullptr; 414 415 MetadataMap* recv_trailing_metadata_ = nullptr; 416 }; 417 418 // A special implementation of InterceptorBatchMethods to send a Cancel 419 // notification down the interceptor stack 420 class CancelInterceptorBatchMethods 421 : public experimental::InterceptorBatchMethods { 422 public: QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)423 bool QueryInterceptionHookPoint( 424 experimental::InterceptionHookPoints type) override { 425 if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) { 426 return true; 427 } else { 428 return false; 429 } 430 } 431 Proceed()432 void Proceed() override { 433 // This is a no-op. For actual continuation of the RPC simply needs to 434 // return from the Intercept method 435 } 436 Hijack()437 void Hijack() override { 438 // Only the client can hijack when sending down initial metadata 439 GPR_CODEGEN_ASSERT(false && 440 "It is illegal to call Hijack on a method which has a " 441 "Cancel notification"); 442 } 443 GetSerializedSendMessage()444 ByteBuffer* GetSerializedSendMessage() override { 445 GPR_CODEGEN_ASSERT(false && 446 "It is illegal to call GetSendMessage on a method which " 447 "has a Cancel notification"); 448 return nullptr; 449 } 450 GetSendMessageStatus()451 bool GetSendMessageStatus() override { 452 GPR_CODEGEN_ASSERT( 453 false && 454 "It is illegal to call GetSendMessageStatus on a method which " 455 "has a Cancel notification"); 456 return false; 457 } 458 GetSendMessage()459 const void* GetSendMessage() override { 460 GPR_CODEGEN_ASSERT( 461 false && 462 "It is illegal to call GetOriginalSendMessage on a method which " 463 "has a Cancel notification"); 464 return nullptr; 465 } 466 ModifySendMessage(const void *)467 void ModifySendMessage(const void* /*message*/) override { 468 GPR_CODEGEN_ASSERT( 469 false && 470 "It is illegal to call ModifySendMessage on a method which " 471 "has a Cancel notification"); 472 } 473 GetSendInitialMetadata()474 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 475 GPR_CODEGEN_ASSERT(false && 476 "It is illegal to call GetSendInitialMetadata on a " 477 "method which has a Cancel notification"); 478 return nullptr; 479 } 480 GetSendStatus()481 Status GetSendStatus() override { 482 GPR_CODEGEN_ASSERT(false && 483 "It is illegal to call GetSendStatus on a method which " 484 "has a Cancel notification"); 485 return Status(); 486 } 487 ModifySendStatus(const Status &)488 void ModifySendStatus(const Status& /*status*/) override { 489 GPR_CODEGEN_ASSERT(false && 490 "It is illegal to call ModifySendStatus on a method " 491 "which has a Cancel notification"); 492 return; 493 } 494 GetSendTrailingMetadata()495 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 496 GPR_CODEGEN_ASSERT(false && 497 "It is illegal to call GetSendTrailingMetadata on a " 498 "method which has a Cancel notification"); 499 return nullptr; 500 } 501 GetRecvMessage()502 void* GetRecvMessage() override { 503 GPR_CODEGEN_ASSERT(false && 504 "It is illegal to call GetRecvMessage on a method which " 505 "has a Cancel notification"); 506 return nullptr; 507 } 508 GetRecvInitialMetadata()509 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 510 override { 511 GPR_CODEGEN_ASSERT(false && 512 "It is illegal to call GetRecvInitialMetadata on a " 513 "method which has a Cancel notification"); 514 return nullptr; 515 } 516 GetRecvStatus()517 Status* GetRecvStatus() override { 518 GPR_CODEGEN_ASSERT(false && 519 "It is illegal to call GetRecvStatus on a method which " 520 "has a Cancel notification"); 521 return nullptr; 522 } 523 GetRecvTrailingMetadata()524 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 525 override { 526 GPR_CODEGEN_ASSERT(false && 527 "It is illegal to call GetRecvTrailingMetadata on a " 528 "method which has a Cancel notification"); 529 return nullptr; 530 } 531 GetInterceptedChannel()532 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 533 GPR_CODEGEN_ASSERT(false && 534 "It is illegal to call GetInterceptedChannel on a " 535 "method which has a Cancel notification"); 536 return std::unique_ptr<ChannelInterface>(nullptr); 537 } 538 FailHijackedRecvMessage()539 void FailHijackedRecvMessage() override { 540 GPR_CODEGEN_ASSERT(false && 541 "It is illegal to call FailHijackedRecvMessage on a " 542 "method which has a Cancel notification"); 543 } 544 FailHijackedSendMessage()545 void FailHijackedSendMessage() override { 546 GPR_CODEGEN_ASSERT(false && 547 "It is illegal to call FailHijackedSendMessage on a " 548 "method which has a Cancel notification"); 549 } 550 }; 551 } // namespace internal 552 } // namespace grpc 553 554 #endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 555