# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test helpers for RPC invocation tests."""

import datetime
import threading

import grpc
from grpc.framework.foundation import logging_pool

from tests.unit import test_common
from tests.unit import thread_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Custom (de)serializers: requests travel doubled and responses tripled on the
# wire; each deserializer slices the original bytes back out.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]

_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking'

# Trailing metadata every handler below attaches when it has a context.
_TRAILING_METADATA = ((
    'testkey',
    'testvalue',
),)

TIMEOUT_SHORT = datetime.timedelta(seconds=1).total_seconds()


class Callback(object):
    """A one-shot callable whose argument can be awaited from another thread.

    Calling the instance records the value and wakes all waiters; `value()`
    blocks until the instance has been called and then returns that value.
    """

    def __init__(self):
        self._condition = threading.Condition()
        self._value = None
        self._called = False

    def __call__(self, value):
        with self._condition:
            self._value = value
            self._called = True
            self._condition.notify_all()

    def value(self):
        """Blocks (without timeout) until called; returns the recorded value."""
        with self._condition:
            while not self._called:
                self._condition.wait()
            return self._value


def _mark_non_blocking(behavior, pool):
    """Returns a fresh wrapper around `behavior` flagged as non-blocking.

    The wrapper forwards every call to `behavior` and carries the
    `experimental_non_blocking` and `experimental_thread_pool` attributes that
    the gRPC server looks up on a handler behavior. Because the wrapper is a
    new function object, those attributes belong to it alone instead of being
    shared through the class function a bound method delegates to.

    Args:
      behavior: The callable to wrap, e.g. a bound non-blocking handler
        method taking `(request_or_iterator, servicer_context, on_next)`.
      pool: The thread pool to advertise via `experimental_thread_pool`.

    Returns:
      The flagged wrapper function.
    """

    def non_blocking_behavior(*args, **kwargs):
        return behavior(*args, **kwargs)

    non_blocking_behavior.experimental_non_blocking = True
    non_blocking_behavior.experimental_thread_pool = pool
    return non_blocking_behavior


class _Handler(object):
    """Servicer behaviors that echo requests back as responses.

    Every behavior calls `control.control()` at fixed checkpoints. Tests use
    these checkpoints to pause or fail the handler through the control object
    (see `BaseRPCTest`). Behaviors also set `_TRAILING_METADATA` whenever a
    servicer context is supplied.
    """

    def __init__(self, control, thread_pool):
        self._control = control
        self._thread_pool = thread_pool
        # Shadow the non-blocking methods with per-instance flagged wrappers.
        # Setting the flags on the shared class functions (via `__func__`)
        # would make every _Handler instance advertise the thread pool of
        # whichever instance was constructed last, and would keep that pool
        # referenced by the class after the test finished.
        for method_name in ('handle_unary_stream_non_blocking',
                            'handle_stream_stream_non_blocking'):
            setattr(
                self, method_name,
                _mark_non_blocking(getattr(self, method_name),
                                   self._thread_pool))

    def handle_unary_unary(self, request, servicer_context):
        """Returns `request` unchanged."""
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
            # TODO(https://github.com/grpc/grpc/issues/8483): test the values
            # returned by these methods rather than only "smoke" testing that
            # the return after having been called.
            servicer_context.is_active()
            servicer_context.time_remaining()
        return request

    def handle_unary_stream(self, request, servicer_context):
        """Yields `request` `test_constants.STREAM_LENGTH` times."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)

    def handle_unary_stream_non_blocking(self, request, servicer_context,
                                         on_next):
        """Pushes `request` STREAM_LENGTH times via `on_next`, then `None`."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            on_next(request)
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
        # A final None signals the end of the response stream.
        on_next(None)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Returns the concatenation of all requests as one bytes response."""
        if servicer_context is not None:
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
        return b''.join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Yields each request back as a response."""
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def handle_stream_stream_non_blocking(self, request_iterator,
                                          servicer_context, on_next):
        """Pushes each request back via `on_next`, then `None`."""
        self._control.control()
        if servicer_context is not None:
            servicer_context.set_trailing_metadata(_TRAILING_METADATA)
        for request in request_iterator:
            self._control.control()
            on_next(request)
        self._control.control()
        # A final None signals the end of the response stream.
        on_next(None)


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing `grpc.RpcMethodHandler`."""

    def __init__(self, request_streaming, response_streaming,
                 request_deserializer, response_serializer, unary_unary,
                 unary_stream, stream_unary, stream_stream):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the `/test/*` method names to the behaviors of a `_Handler`."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Returns the `_MethodHandler` for the method, or None if unknown.

        UnaryUnary and the StreamStream variants use no (de)serializers;
        the other methods use the custom request/response (de)serializers.
        """
        method = handler_call_details.method
        if method == _UNARY_UNARY:
            return _MethodHandler(False, False, None, None,
                                  self._handler.handle_unary_unary, None, None,
                                  None)
        elif method == _UNARY_STREAM:
            return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
                                  _SERIALIZE_RESPONSE, None,
                                  self._handler.handle_unary_stream, None, None)
        elif method == _UNARY_STREAM_NON_BLOCKING:
            return _MethodHandler(
                False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
                self._handler.handle_unary_stream_non_blocking, None, None)
        elif method == _STREAM_UNARY:
            return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
                                  _SERIALIZE_RESPONSE, None, None,
                                  self._handler.handle_stream_unary, None)
        elif method == _STREAM_STREAM:
            return _MethodHandler(True, True, None, None, None, None, None,
                                  self._handler.handle_stream_stream)
        elif method == _STREAM_STREAM_NON_BLOCKING:
            return _MethodHandler(
                True, True, None, None, None, None, None,
                self._handler.handle_stream_stream_non_blocking)
        else:
            return None


def unary_unary_multi_callable(channel):
    """Returns a unary-unary callable for `_UNARY_UNARY` (no serializers)."""
    return channel.unary_unary(_UNARY_UNARY)


def unary_stream_multi_callable(channel):
    """Returns a unary-stream callable for `_UNARY_STREAM`."""
    return channel.unary_stream(_UNARY_STREAM,
                                request_serializer=_SERIALIZE_REQUEST,
                                response_deserializer=_DESERIALIZE_RESPONSE)


def unary_stream_non_blocking_multi_callable(channel):
    """Returns a unary-stream callable for `_UNARY_STREAM_NON_BLOCKING`."""
    return channel.unary_stream(_UNARY_STREAM_NON_BLOCKING,
                                request_serializer=_SERIALIZE_REQUEST,
                                response_deserializer=_DESERIALIZE_RESPONSE)


def stream_unary_multi_callable(channel):
    """Returns a stream-unary callable for `_STREAM_UNARY`."""
    return channel.stream_unary(_STREAM_UNARY,
                                request_serializer=_SERIALIZE_REQUEST,
                                response_deserializer=_DESERIALIZE_RESPONSE)


def stream_stream_multi_callable(channel):
    """Returns a stream-stream callable for `_STREAM_STREAM` (no serializers)."""
    return channel.stream_stream(_STREAM_STREAM)


def stream_stream_non_blocking_multi_callable(channel):
    """Returns a stream-stream callable for `_STREAM_STREAM_NON_BLOCKING`."""
    return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING)


class BaseRPCTest(object):
    """Mixin with server/channel fixtures and shared RPC test scenarios.

    Intended to be combined with a `unittest.TestCase`; the scenario helpers
    use `assertRaises`/`assertIs` and take a multi-callable produced by one of
    the `*_multi_callable` factories above.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None)
        self._handler = _Handler(self._control, self._thread_pool)

        self._server = test_common.test_server()
        port = self._server.add_insecure_port('[::]:0')
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel('localhost:%d' % port)

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def _consume_one_stream_response_unary_request(self, multi_callable):
        """Reads only the first response of a unary-request stream."""
        request = b'\x57\x38'

        response_iterator = multi_callable(
            request,
            metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),))
        next(response_iterator)

    def _consume_some_but_not_all_stream_responses_unary_request(
            self, multi_callable):
        """Reads half of the responses of a unary-request stream."""
        request = b'\x57\x38'

        response_iterator = multi_callable(
            request,
            metadata=(('test',
                       'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_some_but_not_all_stream_responses_stream_request(
            self, multi_callable):
        """Reads half of the responses of a stream-request stream."""
        requests = tuple(
            b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(('test',
                       'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
        for _ in range(test_constants.STREAM_LENGTH // 2):
            next(response_iterator)

    def _consume_too_many_stream_responses_stream_request(self, multi_callable):
        """Drains the stream, then checks extra reads raise StopIteration."""
        requests = tuple(
            b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        response_iterator = multi_callable(
            request_iterator,
            metadata=(('test',
                       'ConsumingTooManyStreamResponsesStreamRequest'),))
        for _ in range(test_constants.STREAM_LENGTH):
            next(response_iterator)
        for _ in range(test_constants.STREAM_LENGTH):
            with self.assertRaises(StopIteration):
                next(response_iterator)

        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.OK, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_unary_request_stream_response(self, multi_callable):
        """Cancels a unary-request stream while the handler is paused."""
        request = b'\x07\x19'

        with self._control.pause():
            response_iterator = multi_callable(
                request,
                metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
            # Cancel only once the handler has actually reached a checkpoint.
            self._control.block_until_paused()
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self.assertIs(grpc.StatusCode.CANCELLED,
                      exception_context.exception.code())
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _cancelled_stream_request_stream_response(self, multi_callable):
        """Cancels a stream-request stream while the handler is paused."""
        requests = tuple(
            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        with self._control.pause():
            response_iterator = multi_callable(
                request_iterator,
                metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
            response_iterator.cancel()

        with self.assertRaises(grpc.RpcError):
            next(response_iterator)
        self.assertIsNotNone(response_iterator.initial_metadata())
        self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
        self.assertIsNotNone(response_iterator.details())
        self.assertIsNotNone(response_iterator.trailing_metadata())

    def _expired_unary_request_stream_response(self, multi_callable):
        """Lets a paused unary-request stream hit its deadline."""
        request = b'\x07\x19'

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      response_iterator.code())

    def _expired_stream_request_stream_response(self, multi_callable):
        """Lets a paused stream-request stream hit its deadline."""
        requests = tuple(
            b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        with self._control.pause():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    timeout=test_constants.SHORT_TIMEOUT,
                    metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
                      response_iterator.code())

    def _failed_unary_request_stream_response(self, multi_callable):
        """Makes the handler fail and expects an UNKNOWN status."""
        request = b'\x37\x17'

        with self.assertRaises(grpc.RpcError) as exception_context:
            with self._control.fail():
                response_iterator = multi_callable(
                    request,
                    metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
                next(response_iterator)

        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())

    def _failed_stream_request_stream_response(self, multi_callable):
        """Makes the handler fail mid-stream and expects an UNKNOWN status."""
        requests = tuple(
            b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        with self._control.fail():
            with self.assertRaises(grpc.RpcError) as exception_context:
                response_iterator = multi_callable(
                    request_iterator,
                    metadata=(('test', 'FailedStreamRequestStreamResponse'),))
                tuple(response_iterator)

        self.assertIs(grpc.StatusCode.UNKNOWN,
                      exception_context.exception.code())
        self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())

    def _ignored_unary_stream_request_future_unary_response(
            self, multi_callable):
        """Starts a unary-request call and never reads its responses.

        The method name is kept for existing callers; the metadata tag
        describes the call as unary-request/stream-response.
        """
        request = b'\x37\x17'

        multi_callable(request,
                       metadata=(('test',
                                  'IgnoredUnaryRequestStreamResponse'),))

    def _ignored_stream_request_stream_response(self, multi_callable):
        """Starts a stream-request call and never reads its responses."""
        requests = tuple(
            b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
        request_iterator = iter(requests)

        multi_callable(request_iterator,
                       metadata=(('test',
                                  'IgnoredStreamRequestStreamResponse'),))