# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test helpers for RPC invocation tests.

Provides an echoing servicer (``_Handler``), its method-handler table
(``get_method_handlers`` / ``_GenericHandler``), client-side multi-callable
factories, and ``BaseRPCTest``, a mixin holding shared setUp/tearDown and
scenario helpers.
"""

import datetime
import threading

import grpc
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit import thread_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Each serializer/deserializer pair round-trips to the original bytes, but the
# on-the-wire form differs from the payload (doubled / tripled). A skipped or
# mismatched (de)serialization step therefore shows up in the echoed payload.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3]

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_UNARY_STREAM_NON_BLOCKING = "UnaryStreamNonBlocking"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"
_STREAM_STREAM_NON_BLOCKING = "StreamStreamNonBlocking"

# Trailing metadata every server handler attaches to its RPC.
_TRAILING_METADATA = (("testkey", "testvalue"),)

TIMEOUT_SHORT = datetime.timedelta(seconds=4).total_seconds()


def _set_trailing_metadata(servicer_context):
    """Attach ``_TRAILING_METADATA`` to ``servicer_context`` unless it is None."""
    if servicer_context is not None:
        servicer_context.set_trailing_metadata(_TRAILING_METADATA)


class Callback(object):
    """Thread-safe holder for a value delivered via a callback.

    An instance is callable; calling it records the given value (the most
    recent call wins) and wakes any thread blocked in ``value()``.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        """Record ``value`` and wake every thread waiting in ``value()``."""
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        """Block, with no timeout, until called at least once; return the value."""
        with self._condition:
            self._condition.wait_for(lambda: self._called)
            return self._value


class _Handler(object):
    """Server-side implementation behind every test method.

    Each handler echoes its request(s) back and calls
    ``self._control.control()`` at fixed checkpoints, so a test can hold or
    fail the RPC mid-flight through the shared control object.
    """

    def __init__(self, control, thread_pool):
        # NOTE: the ``thread_pool`` parameter shadows the imported
        # ``thread_pool`` module inside this method only; the name is kept so
        # keyword callers keep working.
        self._control = control
        self._thread_pool = thread_pool
        non_blocking_functions = (
            self.handle_unary_stream_non_blocking,
            self.handle_stream_stream_non_blocking,
        )
        for non_blocking_function in non_blocking_functions:
            # The attributes are set on the underlying function object, which
            # all _Handler instances share, so the most recently constructed
            # handler's pool is the one recorded. Presumably the gRPC server
            # reads these to run the handler in non-blocking mode on that
            # pool -- NOTE(review): confirm against grpc._server.
            non_blocking_function.__func__.experimental_non_blocking = True
            non_blocking_function.__func__.experimental_thread_pool = (
                self._thread_pool
            )

    def handle_unary_unary(self, request, servicer_context):
        """Echo ``request`` after one control checkpoint."""
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # they return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yield ``request`` STREAM_LENGTH times, checkpointing before each."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        _set_trailing_metadata(servicer_context)

    def handle_unary_stream_non_blocking(
        self, request, servicer_context, on_next
    ):
        """Callback form of ``handle_unary_stream``.

        Pushes ``request`` to ``on_next`` STREAM_LENGTH times, then calls
        ``on_next(None)`` (presumably the end-of-stream signal).
        """
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        _set_trailing_metadata(servicer_context)
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Return every request in the stream concatenated into one bytestring."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        _set_trailing_metadata(servicer_context)
        return b"".join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Echo each incoming request as a response, checkpointing per item."""
        self._control.control()
        _set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(
        self, request_iterator, servicer_context, on_next
    ):
        """Callback form of ``handle_stream_stream``.

        Ends with ``on_next(None)`` (presumably the end-of-stream signal).
        """
        self._control.control()
        _set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        on_next(None)


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute container implementing ``grpc.RpcMethodHandler``.

    Exactly one of the four behavior callables is expected to be non-None,
    matching the two streaming flags.
    """

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic RPC handler serving the same methods as ``get_method_handlers``.

    The method names it matches are the bare names (for example
    ``"UnaryUnary"``), not fully qualified paths.
    """

    def __init__(self, handler):
        self._handler = handler
        # Shares the single method table with get_method_handlers so the
        # generic and registered paths cannot drift apart.
        self._method_handlers = get_method_handlers(handler)

    def service(self, handler_call_details):
        """Return the handler for ``handler_call_details.method``, or None."""
        return self._method_handlers.get(handler_call_details.method)


def get_method_handlers(handler):
    """Build the bare-method-name -> ``_MethodHandler`` table for ``handler``.

    Positional ``_MethodHandler`` arguments are: request_streaming,
    response_streaming, request_deserializer, response_serializer,
    unary_unary, unary_stream, stream_unary, stream_stream. The UnaryUnary
    and StreamStream* methods register no (de)serializers and so carry raw
    bytes.
    """
    return {
        _UNARY_UNARY: _MethodHandler(
            False,
            False,
            None,
            None,
            handler.handle_unary_unary,
            None,
            None,
            None,
        ),
        _UNARY_STREAM: _MethodHandler(
            False,
            True,
            _DESERIALIZE_REQUEST,
            _SERIALIZE_RESPONSE,
            None,
            handler.handle_unary_stream,
            None,
            None,
        ),
        _UNARY_STREAM_NON_BLOCKING: _MethodHandler(
            False,
            True,
            _DESERIALIZE_REQUEST,
            _SERIALIZE_RESPONSE,
            None,
            handler.handle_unary_stream_non_blocking,
            None,
            None,
        ),
        _STREAM_UNARY: _MethodHandler(
            True,
            False,
            _DESERIALIZE_REQUEST,
            _SERIALIZE_RESPONSE,
            None,
            None,
            handler.handle_stream_unary,
            None,
        ),
        _STREAM_STREAM: _MethodHandler(
            True,
            True,
            None,
            None,
            None,
            None,
            None,
            handler.handle_stream_stream,
        ),
        _STREAM_STREAM_NON_BLOCKING: _MethodHandler(
            True,
            True,
            None,
            None,
            None,
            None,
            None,
            handler.handle_stream_stream_non_blocking,
        ),
    }


def unary_unary_multi_callable(channel):
    """Return a client callable for UnaryUnary (raw bytes, no serializers)."""
    return channel.unary_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
        _registered_method=True,
    )


def unary_stream_multi_callable(channel):
    """Return a client callable for UnaryStream with the test serializers."""
    return channel.unary_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM),
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def unary_stream_non_blocking_multi_callable(channel):
    """Return a client callable for UnaryStreamNonBlocking with serializers."""
    return channel.unary_stream(
        grpc._common.fully_qualified_method(
            _SERVICE_NAME, _UNARY_STREAM_NON_BLOCKING
        ),
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_unary_multi_callable(channel):
    """Return a client callable for StreamUnary with the test serializers."""
    return channel.stream_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def stream_stream_multi_callable(channel):
    """Return a client callable for StreamStream (raw bytes, no serializers)."""
    return channel.stream_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
        _registered_method=True,
    )


def stream_stream_non_blocking_multi_callable(channel):
    """Return a client callable for StreamStreamNonBlocking (raw bytes)."""
    return channel.stream_stream(
        grpc._common.fully_qualified_method(
            _SERVICE_NAME, _STREAM_STREAM_NON_BLOCKING
        ),
        _registered_method=True,
    )


class BaseRPCTest(object):
    """Shared fixture and scenario helpers for RPC invocation tests.

    Meant to be mixed into a ``unittest.TestCase``: it relies on
    ``assertRaises``, ``assertIs`` and friends. setUp starts a local server
    that registers the test methods and opens an insecure channel to it.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        # Port 0 lets the OS choose a free port; the bound port is returned.
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(self._handler)
        )
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        # A grace of None stops the server immediately.
        self._server.stop(None)
        self._channel.close()

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Start a unary-request RPC and read only its first response."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(("test", "ConsumingOneStreamResponseUnaryRequest"),),
        )
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
        self, multi_callable
    ):
        """Read the first STREAM_LENGTH // 2 responses of a unary-request RPC."""
        request = b"\x57\x38"

        response_iterator = multi_callable(
            request,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesUnaryRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
        self, multi_callable
    ):
        """Read the first STREAM_LENGTH // 2 responses of a streaming RPC."""
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingSomeButNotAllStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Drain the stream, then check repeated reads keep raising StopIteration.

        Also asserts that the completed RPC reports OK with metadata and
        details present.
        """
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(
                ("test", "ConsumingTooManyStreamResponsesStreamRequest"),
            ),
        )
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancel a unary-request RPC once its handler has paused; expect CANCELLED."""
        request = b"\x07\x19"

        with self._control.pause():
            response_iterator = multi_callable(
                request,
                metadata=(("test", "CancelledUnaryRequestStreamResponse"),),
            )
            # Cancel only after the server handler is held at a checkpoint.
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(
            grpc.StatusCode.CANCELLED, exception_context.exception.code()
        )
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancel a streaming RPC while the control is paused; expect CANCELLED."""
        requests = tuple(
            b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.pause():
            response_iterator = multi_callable(
                request_iterator,
                metadata=(("test", "CancelledStreamRequestStreamResponse"),),
            )
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Hold the handler past SHORT_TIMEOUT; expect DEADLINE_EXCEEDED."""
        request = b"\x07\x19"

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _expired_stream_request_stream_response(self, multi_callable):
        """Streaming-request variant of the deadline-expiry scenario."""
        requests = tuple(
            b"\x67\x18" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(("test", "ExpiredStreamRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED,
            exception_context.exception.code(),
        )
        self.assertIs(
            grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code()
        )

    def _failed_unary_request_stream_response(self, multi_callable):
        """Run with the control in fail mode; expect the RPC to end UNKNOWN."""
        request = b"\x37\x17"

        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    request,
                    metadata=(("test", "FailedUnaryRequestStreamResponse"),),
                )
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def _failed_stream_request_stream_response(self, multi_callable):
        """Streaming-request variant of the fail-mode scenario; expect UNKNOWN."""
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    metadata=(("test", "FailedStreamRequestStreamResponse"),),
                )
                tuple(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
        self, multi_callable
    ):
        """Start a unary-request RPC and discard it without reading a response."""
        request = b"\x37\x17"

        multi_callable(
            request, metadata=(("test", "IgnoredUnaryRequestStreamResponse"),)
        )

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Start a streaming RPC and discard it without reading a response."""
        requests = tuple(
            b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH)
        )
        request_iterator = iter(requests)

        multi_callable(
            request_iterator,
            metadata=(("test", "IgnoredStreamRequestStreamResponse"),),
        )