# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the compatibility between AsyncIO stack and the old stack."""

import asyncio
import logging
import os
import random
import threading
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import TestServiceServicer, start_test_server

_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST = b'\x03\x07'
_ADHOC_METHOD = '/test/AdHoc'


def _unique_options() -> Sequence[Tuple[str, float]]:
    """Returns channel options containing a random-valued dummy entry.

    Each call yields a distinct options tuple. NOTE(review): presumably this
    keeps channels to the same address from being treated as identical /
    shared; confirm against gRPC channel-sharing behavior.
    """
    return (('iv', random.random()),)


class _AdhocGenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving a test-swappable handler for _ADHOC_METHOD.

    Tests install the method handler they need via set_adhoc_handler().
    Requests for any other method name are left unhandled (None).
    """
    # Handler returned for _ADHOC_METHOD; None until a test installs one.
    _handler: Optional[grpc.RpcMethodHandler]

    def __init__(self):
        self._handler = None

    def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
        """Installs `handler` as the handler for _ADHOC_METHOD."""
        self._handler = handler

    def service(self, handler_call_details):
        if handler_call_details.method == _ADHOC_METHOD:
            return self._handler
        return None


@unittest.skipIf(
    os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager',
    'Compatible mode needs POLLER completion queue.')
class TestCompatibility(AioTestBase):
    """Exercises the AsyncIO and sync gRPC stacks side by side.

    The fixture runs an aio server (with a migration thread pool so sync
    handlers can be served) and connects both an aio channel and a sync
    channel to it.
    """

    async def setUp(self):
        self._async_server = aio.server(
            options=(('grpc.so_reuseport', 0),),
            migration_thread_pool=ThreadPoolExecutor())

        test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(),
                                                        self._async_server)
        self._adhoc_handlers = _AdhocGenericHandler()
        self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))

        port = self._async_server.add_insecure_port('[::]:0')
        address = 'localhost:%d' % port
        await self._async_server.start()

        # Create async stub
        self._async_channel = aio.insecure_channel(address,
                                                   options=_unique_options())
        self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel)

        # Create sync stub
        self._sync_channel = grpc.insecure_channel(address,
                                                   options=_unique_options())
        self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel)

    async def tearDown(self):
        self._sync_channel.close()
        await self._async_channel.close()
        await self._async_server.stop(None)

    async def _run_in_another_thread(self, func: Callable[[], None]):
        """Runs `func` in a new thread while the event loop keeps running.

        Any exception raised by `func` (including assertion failures) is
        re-raised here, in the calling coroutine, so the test fails instead
        of waiting forever on an event that would never be set.
        """
        # No `loop=` argument: it was deprecated in 3.8 and removed in 3.10.
        # Constructed inside a coroutine, the event uses the running loop.
        work_done = asyncio.Event()
        errors: List[BaseException] = []

        def thread_work():
            try:
                func()
            except BaseException as e:  # pylint: disable=broad-except
                errors.append(e)
            finally:
                # Always wake the waiting coroutine, even on failure.
                self.loop.call_soon_threadsafe(work_done.set)

        thread = threading.Thread(target=thread_work, daemon=True)
        thread.start()
        await work_done.wait()
        thread.join()
        if errors:
            raise errors[0]

    async def test_unary_unary(self):
        # Calling async API in this thread
        await self._async_stub.UnaryCall(messages_pb2.SimpleRequest(),
                                         timeout=test_constants.LONG_TIMEOUT)

        # Calling sync API in a different thread
        def sync_work() -> None:
            response, call = self._sync_stub.UnaryCall.with_call(
                messages_pb2.SimpleRequest(),
                timeout=test_constants.LONG_TIMEOUT)
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(grpc.StatusCode.OK, call.code())

        await self._run_in_another_thread(sync_work)

    async def test_unary_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Calling async API in this thread
        call = self._async_stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.read()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.StreamingOutputCall(request)
            for response in response_iterator:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_stream_unary(self):
        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Calling async API in this thread
        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        response = await self._async_stub.StreamingInputCall(gen())
        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                         response.aggregated_payload_size)

        # Calling sync API in a different thread
        def sync_work() -> None:
            response = self._sync_stub.StreamingInputCall(
                iter([request] * _NUM_STREAM_RESPONSES))
            self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                             response.aggregated_payload_size)

        await self._run_in_another_thread(sync_work)

    async def test_stream_stream(self):
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))

        # Calling async API in this thread
        call = self._async_stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
            response = await call.read()
            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))

        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
            for response in response_iterator:
                self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
                                 len(response.payload.body))
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_server(self):

        class GenericHandlers(grpc.GenericRpcHandler):

            def service(self, handler_call_details):
                # Echo handler for every method name.
                return grpc.unary_unary_rpc_method_handler(lambda x, _: x)

        # It's fine to instantiate server object in the event loop thread.
        # The server will spawn its own serving thread.
        server = grpc.server(ThreadPoolExecutor(),
                             handlers=(GenericHandlers(),))
        port = server.add_insecure_port('localhost:0')
        server.start()

        def sync_work() -> None:
            for _ in range(100):
                with grpc.insecure_channel('localhost:%d' % port) as channel:
                    response = channel.unary_unary('/test/test')(b'\x07\x08')
                    self.assertEqual(response, b'\x07\x08')

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            # Release the server (and its port) even if the work failed.
            server.stop(None)

    async def test_many_loop(self):
        address, server = await start_test_server()

        # Run another loop in another thread
        def sync_work():

            async def async_work():
                # Create async stub
                async_channel = aio.insecure_channel(address,
                                                     options=_unique_options())
                try:
                    async_stub = test_pb2_grpc.TestServiceStub(async_channel)

                    call = async_stub.UnaryCall(messages_pb2.SimpleRequest())
                    response = await call
                    self.assertIsInstance(response,
                                          messages_pb2.SimpleResponse)
                    self.assertEqual(grpc.StatusCode.OK, await call.code())
                finally:
                    await async_channel.close()

            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(async_work())
            finally:
                # The private loop is owned by this thread; close it here.
                loop.close()

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            await server.stop(None)

    async def test_sync_unary_unary_success(self):

        @grpc.unary_unary_rpc_method_handler
        def echo_unary_unary(request: bytes, unused_context):
            return request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
        response = await self._async_channel.unary_unary(_ADHOC_METHOD)(
            _REQUEST)
        self.assertEqual(_REQUEST, response)

    async def test_sync_unary_unary_metadata(self):
        metadata = (('unique', 'key-42'),)

        @grpc.unary_unary_rpc_method_handler
        def metadata_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.send_initial_metadata(metadata)
            return request

        self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
        call = self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertTrue(
            _common.seen_metadata(aio.Metadata(*metadata), await
                                  call.initial_metadata()))

    async def test_sync_unary_unary_abort(self):

        @grpc.unary_unary_rpc_method_handler
        def abort_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.abort(grpc.StatusCode.INTERNAL, 'Test')

        self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_unary_set_code(self):

        @grpc.unary_unary_rpc_method_handler
        def set_code_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.set_code(grpc.StatusCode.INTERNAL)

        self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_ADHOC_METHOD)(_REQUEST)
        self.assertEqual(grpc.StatusCode.INTERNAL,
                         exception_context.exception.code())

    async def test_sync_unary_stream_success(self):

        @grpc.unary_stream_rpc_method_handler
        def echo_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        received = 0
        async for response in call:
            self.assertEqual(_REQUEST, response)
            received += 1
        # Guard against the loop passing vacuously on an empty stream.
        self.assertEqual(_NUM_STREAM_RESPONSES, received)

    async def test_sync_unary_stream_error(self):

        @grpc.unary_stream_rpc_method_handler
        def error_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
        call = self._async_channel.unary_stream(_ADHOC_METHOD)(_REQUEST)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_unary_success(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            return _REQUEST

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        response = await self._async_channel.stream_unary(_ADHOC_METHOD)(
            request_iterator)
        self.assertEqual(_REQUEST, response)

    async def test_sync_stream_unary_error(self):

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(request_iterator: Iterable[bytes],
                              unused_context):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            raise RuntimeError('Test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.stream_unary(_ADHOC_METHOD)(
                request_iterator)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())

    async def test_sync_stream_stream_success(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        received = 0
        async for response in call:
            self.assertEqual(_REQUEST, response)
            received += 1
        # Guard against the loop passing vacuously on an empty stream.
        self.assertEqual(_NUM_STREAM_RESPONSES, received)

    async def test_sync_stream_stream_error(self):

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(request_iterator: Iterable[bytes],
                               unused_context):
            for request in request_iterator:
                yield request
            raise RuntimeError('test')

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_ADHOC_METHOD)(
            request_iterator)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(grpc.StatusCode.UNKNOWN,
                         exception_context.exception.code())


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)