1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 21 22 #include <array> 23 #include <functional> 24 25 #include <grpcpp/impl/codegen/call.h> 26 #include <grpcpp/impl/codegen/call_op_set_interface.h> 27 #include <grpcpp/impl/codegen/client_interceptor.h> 28 #include <grpcpp/impl/codegen/intercepted_channel.h> 29 #include <grpcpp/impl/codegen/server_interceptor.h> 30 31 #include <grpc/impl/codegen/grpc_types.h> 32 33 namespace grpc { 34 namespace internal { 35 36 class InterceptorBatchMethodsImpl 37 : public experimental::InterceptorBatchMethods { 38 public: InterceptorBatchMethodsImpl()39 InterceptorBatchMethodsImpl() { 40 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 41 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 42 i = static_cast<experimental::InterceptionHookPoints>( 43 static_cast<size_t>(i) + 1)) { 44 hooks_[static_cast<size_t>(i)] = false; 45 } 46 } 47 ~InterceptorBatchMethodsImpl()48 ~InterceptorBatchMethodsImpl() override {} 49 QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50 bool QueryInterceptionHookPoint( 51 experimental::InterceptionHookPoints type) override { 52 return hooks_[static_cast<size_t>(type)]; 53 } 54 Proceed()55 void Proceed() override { 56 if (call_->client_rpc_info() != nullptr) { 57 return ProceedClient(); 58 } 59 GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); 60 ProceedServer(); 61 } 62 Hijack()63 void Hijack() override { 64 // Only the client can hijack when sending down initial metadata 65 GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && 66 call_->client_rpc_info() != nullptr); 67 // It is illegal to call Hijack twice 68 GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); 69 auto* rpc_info = call_->client_rpc_info(); 70 rpc_info->hijacked_ = true; 71 rpc_info->hijacked_interceptor_ = current_interceptor_index_; 72 ClearHookPoints(); 73 ops_->SetHijackingState(); 74 ran_hijacking_interceptor_ = true; 75 rpc_info->RunInterceptor(this, current_interceptor_index_); 76 } 77 AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78 void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { 79 hooks_[static_cast<size_t>(type)] = true; 80 } 81 GetSerializedSendMessage()82 ByteBuffer* GetSerializedSendMessage() override { 83 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 84 if (*orig_send_message_ != nullptr) { 85 GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok()); 86 *orig_send_message_ = nullptr; 87 } 88 return send_message_; 89 } 90 GetSendMessage()91 const void* GetSendMessage() override { 92 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 93 return *orig_send_message_; 94 } 95 ModifySendMessage(const void * message)96 void ModifySendMessage(const void* message) override { 97 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 98 *orig_send_message_ = message; 99 } 100 GetSendMessageStatus()101 bool GetSendMessageStatus() override { return !*fail_send_message_; } 102 GetSendInitialMetadata()103 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 104 return send_initial_metadata_; 105 } 106 GetSendStatus()107 Status GetSendStatus() override { 108 return Status(static_cast<StatusCode>(*code_), *error_message_, 109 *error_details_); 110 } 111 ModifySendStatus(const Status & status)112 void ModifySendStatus(const Status& status) override { 113 *code_ = static_cast<grpc_status_code>(status.error_code()); 114 *error_details_ = status.error_details(); 115 *error_message_ = status.error_message(); 116 } 117 GetSendTrailingMetadata()118 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 119 return send_trailing_metadata_; 120 } 121 GetRecvMessage()122 void* GetRecvMessage() override { return recv_message_; } 123 GetRecvInitialMetadata()124 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 125 override { 126 return recv_initial_metadata_->map(); 127 } 128 GetRecvStatus()129 Status* GetRecvStatus() override { return recv_status_; } 130 FailHijackedSendMessage()131 void FailHijackedSendMessage() override { 132 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 133 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); 134 *fail_send_message_ = true; 135 } 136 GetRecvTrailingMetadata()137 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 138 override { 139 return recv_trailing_metadata_->map(); 140 } 141 SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142 void SetSendMessage(ByteBuffer* buf, const void** msg, 143 bool* fail_send_message, 144 std::function<Status(const void*)> serializer) { 145 send_message_ = buf; 146 orig_send_message_ = msg; 147 fail_send_message_ = fail_send_message; 148 serializer_ = serializer; 149 } 150 SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151 void SetSendInitialMetadata( 152 std::multimap<std::string, std::string>* metadata) { 153 send_initial_metadata_ = metadata; 154 } 155 SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156 void SetSendStatus(grpc_status_code* code, std::string* error_details, 157 std::string* error_message) { 158 code_ = code; 159 error_details_ = error_details; 160 error_message_ = error_message; 161 } 162 SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163 void SetSendTrailingMetadata( 164 std::multimap<std::string, std::string>* metadata) { 165 send_trailing_metadata_ = metadata; 166 } 167 SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168 void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) { 169 recv_message_ = message; 170 hijacked_recv_message_failed_ = hijacked_recv_message_failed; 171 } 172 SetRecvInitialMetadata(MetadataMap * map)173 void SetRecvInitialMetadata(MetadataMap* map) { 174 recv_initial_metadata_ = map; 175 } 176 SetRecvStatus(Status * status)177 void SetRecvStatus(Status* status) { recv_status_ = status; } 178 SetRecvTrailingMetadata(MetadataMap * map)179 void SetRecvTrailingMetadata(MetadataMap* map) { 180 recv_trailing_metadata_ = map; 181 } 182 GetInterceptedChannel()183 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 184 auto* info = call_->client_rpc_info(); 185 if (info == nullptr) { 186 return std::unique_ptr<ChannelInterface>(nullptr); 187 } 188 // The intercepted channel starts from the interceptor just after the 189 // current interceptor 190 return std::unique_ptr<ChannelInterface>(new InterceptedChannel( 191 info->channel(), current_interceptor_index_ + 1)); 192 } 193 FailHijackedRecvMessage()194 void FailHijackedRecvMessage() override { 195 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 196 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); 197 *hijacked_recv_message_failed_ = true; 198 } 199 200 // Clears all state ClearState()201 void ClearState() { 202 reverse_ = false; 203 ran_hijacking_interceptor_ = false; 204 ClearHookPoints(); 205 } 206 207 // Prepares for Post_recv operations SetReverse()208 void SetReverse() { 209 reverse_ = true; 210 ran_hijacking_interceptor_ = false; 211 ClearHookPoints(); 212 } 213 214 // This needs to be set before interceptors are run SetCall(Call * call)215 void SetCall(Call* call) { call_ = call; } 216 217 // This needs to be set before interceptors are run using RunInterceptors(). 218 // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. SetCallOpSetInterface(CallOpSetInterface * ops)219 void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } 220 221 // SetCall should have been called before this. 222 // Returns true if the interceptors list is empty InterceptorsListEmpty()223 bool InterceptorsListEmpty() { 224 auto* client_rpc_info = call_->client_rpc_info(); 225 if (client_rpc_info != nullptr) { 226 if (client_rpc_info->interceptors_.empty()) { 227 return true; 228 } else { 229 return false; 230 } 231 } 232 233 auto* server_rpc_info = call_->server_rpc_info(); 234 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 235 return true; 236 } 237 return false; 238 } 239 240 // This should be used only by subclasses of CallOpSetInterface. SetCall and 241 // SetCallOpSetInterface should have been called before this. After all the 242 // interceptors are done running, either ContinueFillOpsAfterInterception or 243 // ContinueFinalizeOpsAfterInterception will be called. Note that neither of 244 // them is invoked if there were no interceptors registered. RunInterceptors()245 bool RunInterceptors() { 246 GPR_CODEGEN_ASSERT(ops_); 247 auto* client_rpc_info = call_->client_rpc_info(); 248 if (client_rpc_info != nullptr) { 249 if (client_rpc_info->interceptors_.empty()) { 250 return true; 251 } else { 252 RunClientInterceptors(); 253 return false; 254 } 255 } 256 257 auto* server_rpc_info = call_->server_rpc_info(); 258 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 259 return true; 260 } 261 RunServerInterceptors(); 262 return false; 263 } 264 265 // Returns true if no interceptors are run. Returns false otherwise if there 266 // are interceptors registered. After the interceptors are done running \a f 267 // will be invoked. This is to be used only by BaseAsyncRequest and 268 // SyncRequest. RunInterceptors(std::function<void (void)> f)269 bool RunInterceptors(std::function<void(void)> f) { 270 // This is used only by the server for initial call request 271 GPR_CODEGEN_ASSERT(reverse_ == true); 272 GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); 273 auto* server_rpc_info = call_->server_rpc_info(); 274 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 275 return true; 276 } 277 callback_ = std::move(f); 278 RunServerInterceptors(); 279 return false; 280 } 281 282 private: RunClientInterceptors()283 void RunClientInterceptors() { 284 auto* rpc_info = call_->client_rpc_info(); 285 if (!reverse_) { 286 current_interceptor_index_ = 0; 287 } else { 288 if (rpc_info->hijacked_) { 289 current_interceptor_index_ = rpc_info->hijacked_interceptor_; 290 } else { 291 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 292 } 293 } 294 rpc_info->RunInterceptor(this, current_interceptor_index_); 295 } 296 RunServerInterceptors()297 void RunServerInterceptors() { 298 auto* rpc_info = call_->server_rpc_info(); 299 if (!reverse_) { 300 current_interceptor_index_ = 0; 301 } else { 302 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 303 } 304 rpc_info->RunInterceptor(this, current_interceptor_index_); 305 } 306 ProceedClient()307 void ProceedClient() { 308 auto* rpc_info = call_->client_rpc_info(); 309 if (rpc_info->hijacked_ && !reverse_ && 310 current_interceptor_index_ == rpc_info->hijacked_interceptor_ && 311 !ran_hijacking_interceptor_) { 312 // We now need to provide hijacked recv ops to this interceptor 313 ClearHookPoints(); 314 ops_->SetHijackingState(); 315 ran_hijacking_interceptor_ = true; 316 rpc_info->RunInterceptor(this, current_interceptor_index_); 317 return; 318 } 319 if (!reverse_) { 320 current_interceptor_index_++; 321 // We are going down the stack of interceptors 322 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 323 if (rpc_info->hijacked_ && 324 current_interceptor_index_ > rpc_info->hijacked_interceptor_) { 325 // This is a hijacked RPC and we are done with hijacking 326 ops_->ContinueFillOpsAfterInterception(); 327 } else { 328 rpc_info->RunInterceptor(this, current_interceptor_index_); 329 } 330 } else { 331 // we are done running all the interceptors without any hijacking 332 ops_->ContinueFillOpsAfterInterception(); 333 } 334 } else { 335 // We are going up the stack of interceptors 336 if (current_interceptor_index_ > 0) { 337 // Continue running interceptors 338 current_interceptor_index_--; 339 rpc_info->RunInterceptor(this, current_interceptor_index_); 340 } else { 341 // we are done running all the interceptors without any hijacking 342 ops_->ContinueFinalizeResultAfterInterception(); 343 } 344 } 345 } 346 ProceedServer()347 void ProceedServer() { 348 auto* rpc_info = call_->server_rpc_info(); 349 if (!reverse_) { 350 current_interceptor_index_++; 351 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 352 return rpc_info->RunInterceptor(this, current_interceptor_index_); 353 } else if (ops_) { 354 return ops_->ContinueFillOpsAfterInterception(); 355 } 356 } else { 357 // We are going up the stack of interceptors 358 if (current_interceptor_index_ > 0) { 359 // Continue running interceptors 360 current_interceptor_index_--; 361 return rpc_info->RunInterceptor(this, current_interceptor_index_); 362 } else if (ops_) { 363 return ops_->ContinueFinalizeResultAfterInterception(); 364 } 365 } 366 GPR_CODEGEN_ASSERT(callback_); 367 callback_(); 368 } 369 ClearHookPoints()370 void ClearHookPoints() { 371 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 372 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 373 i = static_cast<experimental::InterceptionHookPoints>( 374 static_cast<size_t>(i) + 1)) { 375 hooks_[static_cast<size_t>(i)] = false; 376 } 377 } 378 379 std::array<bool, 380 static_cast<size_t>( 381 experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> 382 hooks_; 383 384 size_t current_interceptor_index_ = 0; // Current iterator 385 bool reverse_ = false; 386 bool ran_hijacking_interceptor_ = false; 387 Call* call_ = nullptr; // The Call object is present along with CallOpSet 388 // object/callback 389 CallOpSetInterface* ops_ = nullptr; 390 std::function<void(void)> callback_; 391 392 ByteBuffer* send_message_ = nullptr; 393 bool* fail_send_message_ = nullptr; 394 const void** orig_send_message_ = nullptr; 395 std::function<Status(const void*)> serializer_; 396 397 std::multimap<std::string, std::string>* send_initial_metadata_; 398 399 grpc_status_code* code_ = nullptr; 400 std::string* error_details_ = nullptr; 401 std::string* error_message_ = nullptr; 402 403 std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr; 404 405 void* recv_message_ = nullptr; 406 bool* hijacked_recv_message_failed_ = nullptr; 407 408 MetadataMap* recv_initial_metadata_ = nullptr; 409 410 Status* recv_status_ = nullptr; 411 412 MetadataMap* recv_trailing_metadata_ = nullptr; 413 }; 414 415 // A special implementation of InterceptorBatchMethods to send a Cancel 416 // notification down the interceptor stack 417 class CancelInterceptorBatchMethods 418 : public experimental::InterceptorBatchMethods { 419 public: QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)420 bool QueryInterceptionHookPoint( 421 experimental::InterceptionHookPoints type) override { 422 if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) { 423 return true; 424 } else { 425 return false; 426 } 427 } 428 Proceed()429 void Proceed() override { 430 // This is a no-op. For actual continuation of the RPC simply needs to 431 // return from the Intercept method 432 } 433 Hijack()434 void Hijack() override { 435 // Only the client can hijack when sending down initial metadata 436 GPR_CODEGEN_ASSERT(false && 437 "It is illegal to call Hijack on a method which has a " 438 "Cancel notification"); 439 } 440 GetSerializedSendMessage()441 ByteBuffer* GetSerializedSendMessage() override { 442 GPR_CODEGEN_ASSERT(false && 443 "It is illegal to call GetSendMessage on a method which " 444 "has a Cancel notification"); 445 return nullptr; 446 } 447 GetSendMessageStatus()448 bool GetSendMessageStatus() override { 449 GPR_CODEGEN_ASSERT( 450 false && 451 "It is illegal to call GetSendMessageStatus on a method which " 452 "has a Cancel notification"); 453 return false; 454 } 455 GetSendMessage()456 const void* GetSendMessage() override { 457 GPR_CODEGEN_ASSERT( 458 false && 459 "It is illegal to call GetOriginalSendMessage on a method which " 460 "has a Cancel notification"); 461 return nullptr; 462 } 463 ModifySendMessage(const void *)464 void ModifySendMessage(const void* /*message*/) override { 465 GPR_CODEGEN_ASSERT( 466 false && 467 "It is illegal to call ModifySendMessage on a method which " 468 "has a Cancel notification"); 469 } 470 GetSendInitialMetadata()471 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 472 GPR_CODEGEN_ASSERT(false && 473 "It is illegal to call GetSendInitialMetadata on a " 474 "method which has a Cancel notification"); 475 return nullptr; 476 } 477 GetSendStatus()478 Status GetSendStatus() override { 479 GPR_CODEGEN_ASSERT(false && 480 "It is illegal to call GetSendStatus on a method which " 481 "has a Cancel notification"); 482 return Status(); 483 } 484 ModifySendStatus(const Status &)485 void ModifySendStatus(const Status& /*status*/) override { 486 GPR_CODEGEN_ASSERT(false && 487 "It is illegal to call ModifySendStatus on a method " 488 "which has a Cancel notification"); 489 } 490 GetSendTrailingMetadata()491 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 492 GPR_CODEGEN_ASSERT(false && 493 "It is illegal to call GetSendTrailingMetadata on a " 494 "method which has a Cancel notification"); 495 return nullptr; 496 } 497 GetRecvMessage()498 void* GetRecvMessage() override { 499 GPR_CODEGEN_ASSERT(false && 500 "It is illegal to call GetRecvMessage on a method which " 501 "has a Cancel notification"); 502 return nullptr; 503 } 504 GetRecvInitialMetadata()505 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 506 override { 507 GPR_CODEGEN_ASSERT(false && 508 "It is illegal to call GetRecvInitialMetadata on a " 509 "method which has a Cancel notification"); 510 return nullptr; 511 } 512 GetRecvStatus()513 Status* GetRecvStatus() override { 514 GPR_CODEGEN_ASSERT(false && 515 "It is illegal to call GetRecvStatus on a method which " 516 "has a Cancel notification"); 517 return nullptr; 518 } 519 GetRecvTrailingMetadata()520 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 521 override { 522 GPR_CODEGEN_ASSERT(false && 523 "It is illegal to call GetRecvTrailingMetadata on a " 524 "method which has a Cancel notification"); 525 return nullptr; 526 } 527 GetInterceptedChannel()528 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 529 GPR_CODEGEN_ASSERT(false && 530 "It is illegal to call GetInterceptedChannel on a " 531 "method which has a Cancel notification"); 532 return std::unique_ptr<ChannelInterface>(nullptr); 533 } 534 FailHijackedRecvMessage()535 void FailHijackedRecvMessage() override { 536 GPR_CODEGEN_ASSERT(false && 537 "It is illegal to call FailHijackedRecvMessage on a " 538 "method which has a Cancel notification"); 539 } 540 FailHijackedSendMessage()541 void FailHijackedSendMessage() override { 542 GPR_CODEGEN_ASSERT(false && 543 "It is illegal to call FailHijackedSendMessage on a " 544 "method which has a Cancel notification"); 545 } 546 }; 547 } // namespace internal 548 } // namespace grpc 549 550 #endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 551