1 /* 2 * 3 * Copyright 2018 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 20 #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 21 22 #include <array> 23 #include <functional> 24 25 #include <grpcpp/impl/codegen/call.h> 26 #include <grpcpp/impl/codegen/call_op_set_interface.h> 27 #include <grpcpp/impl/codegen/client_interceptor.h> 28 #include <grpcpp/impl/codegen/intercepted_channel.h> 29 #include <grpcpp/impl/codegen/server_interceptor.h> 30 31 #include <grpc/impl/codegen/grpc_types.h> 32 33 namespace grpc { 34 namespace internal { 35 36 class InterceptorBatchMethodsImpl 37 : public experimental::InterceptorBatchMethods { 38 public: InterceptorBatchMethodsImpl()39 InterceptorBatchMethodsImpl() { 40 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 41 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 42 i = static_cast<experimental::InterceptionHookPoints>( 43 static_cast<size_t>(i) + 1)) { 44 hooks_[static_cast<size_t>(i)] = false; 45 } 46 } 47 ~InterceptorBatchMethodsImpl()48 ~InterceptorBatchMethodsImpl() override {} 49 QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)50 bool QueryInterceptionHookPoint( 51 experimental::InterceptionHookPoints type) override { 52 return hooks_[static_cast<size_t>(type)]; 53 } 54 Proceed()55 void Proceed() override { 56 if (call_->client_rpc_info() != nullptr) { 57 return ProceedClient(); 58 } 59 GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); 60 ProceedServer(); 61 } 62 Hijack()63 void Hijack() override { 64 // Only the client can hijack when sending down initial metadata 65 GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && 66 call_->client_rpc_info() != nullptr); 67 // It is illegal to call Hijack twice 68 GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); 69 auto* rpc_info = call_->client_rpc_info(); 70 rpc_info->hijacked_ = true; 71 rpc_info->hijacked_interceptor_ = current_interceptor_index_; 72 ClearHookPoints(); 73 ops_->SetHijackingState(); 74 ran_hijacking_interceptor_ = true; 75 rpc_info->RunInterceptor(this, current_interceptor_index_); 76 } 77 AddInterceptionHookPoint(experimental::InterceptionHookPoints type)78 void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { 79 hooks_[static_cast<size_t>(type)] = true; 80 } 81 GetSerializedSendMessage()82 ByteBuffer* GetSerializedSendMessage() override { 83 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 84 if (*orig_send_message_ != nullptr) { 85 GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok()); 86 *orig_send_message_ = nullptr; 87 } 88 return send_message_; 89 } 90 GetSendMessage()91 const void* GetSendMessage() override { 92 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 93 return *orig_send_message_; 94 } 95 ModifySendMessage(const void * message)96 void ModifySendMessage(const void* message) override { 97 GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); 98 *orig_send_message_ = message; 99 } 100 GetSendMessageStatus()101 bool GetSendMessageStatus() override { return !*fail_send_message_; } 102 GetSendInitialMetadata()103 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 104 return send_initial_metadata_; 105 } 106 GetSendStatus()107 Status GetSendStatus() override { 108 return Status(static_cast<StatusCode>(*code_), *error_message_, 109 *error_details_); 110 } 111 ModifySendStatus(const Status & status)112 void ModifySendStatus(const Status& status) override { 113 *code_ = static_cast<grpc_status_code>(status.error_code()); 114 *error_details_ = status.error_details(); 115 *error_message_ = status.error_message(); 116 } 117 GetSendTrailingMetadata()118 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 119 return send_trailing_metadata_; 120 } 121 GetRecvMessage()122 void* GetRecvMessage() override { return recv_message_; } 123 GetRecvInitialMetadata()124 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 125 override { 126 return recv_initial_metadata_->map(); 127 } 128 GetRecvStatus()129 Status* GetRecvStatus() override { return recv_status_; } 130 FailHijackedSendMessage()131 void FailHijackedSendMessage() override { 132 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 133 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); 134 *fail_send_message_ = true; 135 } 136 GetRecvTrailingMetadata()137 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 138 override { 139 return recv_trailing_metadata_->map(); 140 } 141 SetSendMessage(ByteBuffer * buf,const void ** msg,bool * fail_send_message,std::function<Status (const void *)> serializer)142 void SetSendMessage(ByteBuffer* buf, const void** msg, 143 bool* fail_send_message, 144 std::function<Status(const void*)> serializer) { 145 send_message_ = buf; 146 orig_send_message_ = msg; 147 fail_send_message_ = fail_send_message; 148 serializer_ = serializer; 149 } 150 SetSendInitialMetadata(std::multimap<std::string,std::string> * metadata)151 void SetSendInitialMetadata( 152 std::multimap<std::string, std::string>* metadata) { 153 send_initial_metadata_ = metadata; 154 } 155 SetSendStatus(grpc_status_code * code,std::string * error_details,std::string * error_message)156 void SetSendStatus(grpc_status_code* code, std::string* error_details, 157 std::string* error_message) { 158 code_ = code; 159 error_details_ = error_details; 160 error_message_ = error_message; 161 } 162 SetSendTrailingMetadata(std::multimap<std::string,std::string> * metadata)163 void SetSendTrailingMetadata( 164 std::multimap<std::string, std::string>* metadata) { 165 send_trailing_metadata_ = metadata; 166 } 167 SetRecvMessage(void * message,bool * hijacked_recv_message_failed)168 void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) { 169 recv_message_ = message; 170 hijacked_recv_message_failed_ = hijacked_recv_message_failed; 171 } 172 SetRecvInitialMetadata(MetadataMap * map)173 void SetRecvInitialMetadata(MetadataMap* map) { 174 recv_initial_metadata_ = map; 175 } 176 SetRecvStatus(Status * status)177 void SetRecvStatus(Status* status) { recv_status_ = status; } 178 SetRecvTrailingMetadata(MetadataMap * map)179 void SetRecvTrailingMetadata(MetadataMap* map) { 180 recv_trailing_metadata_ = map; 181 } 182 GetInterceptedChannel()183 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 184 auto* info = call_->client_rpc_info(); 185 if (info == nullptr) { 186 return std::unique_ptr<ChannelInterface>(nullptr); 187 } 188 // The intercepted channel starts from the interceptor just after the 189 // current interceptor 190 return std::unique_ptr<ChannelInterface>(new InterceptedChannel( 191 info->channel(), current_interceptor_index_ + 1)); 192 } 193 FailHijackedRecvMessage()194 void FailHijackedRecvMessage() override { 195 GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( 196 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); 197 *hijacked_recv_message_failed_ = true; 198 } 199 200 // Clears all state ClearState()201 void ClearState() { 202 reverse_ = false; 203 ran_hijacking_interceptor_ = false; 204 ClearHookPoints(); 205 } 206 207 // Prepares for Post_recv operations SetReverse()208 void SetReverse() { 209 reverse_ = true; 210 ran_hijacking_interceptor_ = false; 211 ClearHookPoints(); 212 } 213 214 // This needs to be set before interceptors are run SetCall(Call * call)215 void SetCall(Call* call) { call_ = call; } 216 217 // This needs to be set before interceptors are run using RunInterceptors(). 218 // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. SetCallOpSetInterface(CallOpSetInterface * ops)219 void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } 220 221 // SetCall should have been called before this. 222 // Returns true if the interceptors list is empty InterceptorsListEmpty()223 bool InterceptorsListEmpty() { 224 auto* client_rpc_info = call_->client_rpc_info(); 225 if (client_rpc_info != nullptr) { 226 return client_rpc_info->interceptors_.empty(); 227 } 228 229 auto* server_rpc_info = call_->server_rpc_info(); 230 return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty(); 231 } 232 233 // This should be used only by subclasses of CallOpSetInterface. SetCall and 234 // SetCallOpSetInterface should have been called before this. After all the 235 // interceptors are done running, either ContinueFillOpsAfterInterception or 236 // ContinueFinalizeOpsAfterInterception will be called. Note that neither of 237 // them is invoked if there were no interceptors registered. RunInterceptors()238 bool RunInterceptors() { 239 GPR_CODEGEN_ASSERT(ops_); 240 auto* client_rpc_info = call_->client_rpc_info(); 241 if (client_rpc_info != nullptr) { 242 if (client_rpc_info->interceptors_.empty()) { 243 return true; 244 } else { 245 RunClientInterceptors(); 246 return false; 247 } 248 } 249 250 auto* server_rpc_info = call_->server_rpc_info(); 251 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 252 return true; 253 } 254 RunServerInterceptors(); 255 return false; 256 } 257 258 // Returns true if no interceptors are run. Returns false otherwise if there 259 // are interceptors registered. After the interceptors are done running \a f 260 // will be invoked. This is to be used only by BaseAsyncRequest and 261 // SyncRequest. RunInterceptors(std::function<void (void)> f)262 bool RunInterceptors(std::function<void(void)> f) { 263 // This is used only by the server for initial call request 264 GPR_CODEGEN_ASSERT(reverse_ == true); 265 GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); 266 auto* server_rpc_info = call_->server_rpc_info(); 267 if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) { 268 return true; 269 } 270 callback_ = std::move(f); 271 RunServerInterceptors(); 272 return false; 273 } 274 275 private: RunClientInterceptors()276 void RunClientInterceptors() { 277 auto* rpc_info = call_->client_rpc_info(); 278 if (!reverse_) { 279 current_interceptor_index_ = 0; 280 } else { 281 if (rpc_info->hijacked_) { 282 current_interceptor_index_ = rpc_info->hijacked_interceptor_; 283 } else { 284 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 285 } 286 } 287 rpc_info->RunInterceptor(this, current_interceptor_index_); 288 } 289 RunServerInterceptors()290 void RunServerInterceptors() { 291 auto* rpc_info = call_->server_rpc_info(); 292 if (!reverse_) { 293 current_interceptor_index_ = 0; 294 } else { 295 current_interceptor_index_ = rpc_info->interceptors_.size() - 1; 296 } 297 rpc_info->RunInterceptor(this, current_interceptor_index_); 298 } 299 ProceedClient()300 void ProceedClient() { 301 auto* rpc_info = call_->client_rpc_info(); 302 if (rpc_info->hijacked_ && !reverse_ && 303 current_interceptor_index_ == rpc_info->hijacked_interceptor_ && 304 !ran_hijacking_interceptor_) { 305 // We now need to provide hijacked recv ops to this interceptor 306 ClearHookPoints(); 307 ops_->SetHijackingState(); 308 ran_hijacking_interceptor_ = true; 309 rpc_info->RunInterceptor(this, current_interceptor_index_); 310 return; 311 } 312 if (!reverse_) { 313 current_interceptor_index_++; 314 // We are going down the stack of interceptors 315 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 316 if (rpc_info->hijacked_ && 317 current_interceptor_index_ > rpc_info->hijacked_interceptor_) { 318 // This is a hijacked RPC and we are done with hijacking 319 ops_->ContinueFillOpsAfterInterception(); 320 } else { 321 rpc_info->RunInterceptor(this, current_interceptor_index_); 322 } 323 } else { 324 // we are done running all the interceptors without any hijacking 325 ops_->ContinueFillOpsAfterInterception(); 326 } 327 } else { 328 // We are going up the stack of interceptors 329 if (current_interceptor_index_ > 0) { 330 // Continue running interceptors 331 current_interceptor_index_--; 332 rpc_info->RunInterceptor(this, current_interceptor_index_); 333 } else { 334 // we are done running all the interceptors without any hijacking 335 ops_->ContinueFinalizeResultAfterInterception(); 336 } 337 } 338 } 339 ProceedServer()340 void ProceedServer() { 341 auto* rpc_info = call_->server_rpc_info(); 342 if (!reverse_) { 343 current_interceptor_index_++; 344 if (current_interceptor_index_ < rpc_info->interceptors_.size()) { 345 return rpc_info->RunInterceptor(this, current_interceptor_index_); 346 } else if (ops_) { 347 return ops_->ContinueFillOpsAfterInterception(); 348 } 349 } else { 350 // We are going up the stack of interceptors 351 if (current_interceptor_index_ > 0) { 352 // Continue running interceptors 353 current_interceptor_index_--; 354 return rpc_info->RunInterceptor(this, current_interceptor_index_); 355 } else if (ops_) { 356 return ops_->ContinueFinalizeResultAfterInterception(); 357 } 358 } 359 GPR_CODEGEN_ASSERT(callback_); 360 callback_(); 361 } 362 ClearHookPoints()363 void ClearHookPoints() { 364 for (auto i = static_cast<experimental::InterceptionHookPoints>(0); 365 i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; 366 i = static_cast<experimental::InterceptionHookPoints>( 367 static_cast<size_t>(i) + 1)) { 368 hooks_[static_cast<size_t>(i)] = false; 369 } 370 } 371 372 std::array<bool, 373 static_cast<size_t>( 374 experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> 375 hooks_; 376 377 size_t current_interceptor_index_ = 0; // Current iterator 378 bool reverse_ = false; 379 bool ran_hijacking_interceptor_ = false; 380 Call* call_ = nullptr; // The Call object is present along with CallOpSet 381 // object/callback 382 CallOpSetInterface* ops_ = nullptr; 383 std::function<void(void)> callback_; 384 385 ByteBuffer* send_message_ = nullptr; 386 bool* fail_send_message_ = nullptr; 387 const void** orig_send_message_ = nullptr; 388 std::function<Status(const void*)> serializer_; 389 390 std::multimap<std::string, std::string>* send_initial_metadata_; 391 392 grpc_status_code* code_ = nullptr; 393 std::string* error_details_ = nullptr; 394 std::string* error_message_ = nullptr; 395 396 std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr; 397 398 void* recv_message_ = nullptr; 399 bool* hijacked_recv_message_failed_ = nullptr; 400 401 MetadataMap* recv_initial_metadata_ = nullptr; 402 403 Status* recv_status_ = nullptr; 404 405 MetadataMap* recv_trailing_metadata_ = nullptr; 406 }; 407 408 // A special implementation of InterceptorBatchMethods to send a Cancel 409 // notification down the interceptor stack 410 class CancelInterceptorBatchMethods 411 : public experimental::InterceptorBatchMethods { 412 public: QueryInterceptionHookPoint(experimental::InterceptionHookPoints type)413 bool QueryInterceptionHookPoint( 414 experimental::InterceptionHookPoints type) override { 415 return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL; 416 } 417 Proceed()418 void Proceed() override { 419 // This is a no-op. For actual continuation of the RPC simply needs to 420 // return from the Intercept method 421 } 422 Hijack()423 void Hijack() override { 424 // Only the client can hijack when sending down initial metadata 425 GPR_CODEGEN_ASSERT(false && 426 "It is illegal to call Hijack on a method which has a " 427 "Cancel notification"); 428 } 429 GetSerializedSendMessage()430 ByteBuffer* GetSerializedSendMessage() override { 431 GPR_CODEGEN_ASSERT(false && 432 "It is illegal to call GetSendMessage on a method which " 433 "has a Cancel notification"); 434 return nullptr; 435 } 436 GetSendMessageStatus()437 bool GetSendMessageStatus() override { 438 GPR_CODEGEN_ASSERT( 439 false && 440 "It is illegal to call GetSendMessageStatus on a method which " 441 "has a Cancel notification"); 442 return false; 443 } 444 GetSendMessage()445 const void* GetSendMessage() override { 446 GPR_CODEGEN_ASSERT( 447 false && 448 "It is illegal to call GetOriginalSendMessage on a method which " 449 "has a Cancel notification"); 450 return nullptr; 451 } 452 ModifySendMessage(const void *)453 void ModifySendMessage(const void* /*message*/) override { 454 GPR_CODEGEN_ASSERT( 455 false && 456 "It is illegal to call ModifySendMessage on a method which " 457 "has a Cancel notification"); 458 } 459 GetSendInitialMetadata()460 std::multimap<std::string, std::string>* GetSendInitialMetadata() override { 461 GPR_CODEGEN_ASSERT(false && 462 "It is illegal to call GetSendInitialMetadata on a " 463 "method which has a Cancel notification"); 464 return nullptr; 465 } 466 GetSendStatus()467 Status GetSendStatus() override { 468 GPR_CODEGEN_ASSERT(false && 469 "It is illegal to call GetSendStatus on a method which " 470 "has a Cancel notification"); 471 return Status(); 472 } 473 ModifySendStatus(const Status &)474 void ModifySendStatus(const Status& /*status*/) override { 475 GPR_CODEGEN_ASSERT(false && 476 "It is illegal to call ModifySendStatus on a method " 477 "which has a Cancel notification"); 478 } 479 GetSendTrailingMetadata()480 std::multimap<std::string, std::string>* GetSendTrailingMetadata() override { 481 GPR_CODEGEN_ASSERT(false && 482 "It is illegal to call GetSendTrailingMetadata on a " 483 "method which has a Cancel notification"); 484 return nullptr; 485 } 486 GetRecvMessage()487 void* GetRecvMessage() override { 488 GPR_CODEGEN_ASSERT(false && 489 "It is illegal to call GetRecvMessage on a method which " 490 "has a Cancel notification"); 491 return nullptr; 492 } 493 GetRecvInitialMetadata()494 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() 495 override { 496 GPR_CODEGEN_ASSERT(false && 497 "It is illegal to call GetRecvInitialMetadata on a " 498 "method which has a Cancel notification"); 499 return nullptr; 500 } 501 GetRecvStatus()502 Status* GetRecvStatus() override { 503 GPR_CODEGEN_ASSERT(false && 504 "It is illegal to call GetRecvStatus on a method which " 505 "has a Cancel notification"); 506 return nullptr; 507 } 508 GetRecvTrailingMetadata()509 std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() 510 override { 511 GPR_CODEGEN_ASSERT(false && 512 "It is illegal to call GetRecvTrailingMetadata on a " 513 "method which has a Cancel notification"); 514 return nullptr; 515 } 516 GetInterceptedChannel()517 std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { 518 GPR_CODEGEN_ASSERT(false && 519 "It is illegal to call GetInterceptedChannel on a " 520 "method which has a Cancel notification"); 521 return std::unique_ptr<ChannelInterface>(nullptr); 522 } 523 FailHijackedRecvMessage()524 void FailHijackedRecvMessage() override { 525 GPR_CODEGEN_ASSERT(false && 526 "It is illegal to call FailHijackedRecvMessage on a " 527 "method which has a Cancel notification"); 528 } 529 FailHijackedSendMessage()530 void FailHijackedSendMessage() override { 531 GPR_CODEGEN_ASSERT(false && 532 "It is illegal to call FailHijackedSendMessage on a " 533 "method which has a Cancel notification"); 534 } 535 }; 536 } // namespace internal 537 } // namespace grpc 538 539 #endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H 540