# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the compatibility between AsyncIO stack and the old stack."""

import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import os
import random
import threading
from typing import Callable, Iterable, Sequence, Tuple
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import TestServiceServicer
from tests_aio.unit._test_server import start_test_server

_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
_REQUEST = b"\x03\x07"


def _unique_options() -> Sequence[Tuple[str, float]]:
    """Return a one-item channel option tuple carrying a random value.

    NOTE(review): presumably used so that every channel is created with
    distinct arguments and nothing is shared between channels -- confirm.
    """
    return (("iv", random.random()),)


@unittest.skipIf(
    os.environ.get("GRPC_ASYNCIO_ENGINE", "").lower() == "custom_io_manager",
    "Compatible mode needs POLLER completion queue.",
)
class TestCompatibility(AioTestBase):
    """Exercises the AsyncIO API and the sync API side by side.

    Each test gets one AsyncIO server plus an async stub and a sync stub that
    both point at it.
    """

    async def setUp(self):
        """Start the AsyncIO server and create the async and sync stubs."""
        self._async_server = aio.server(
            options=(("grpc.so_reuseport", 0),),
            migration_thread_pool=ThreadPoolExecutor(),
        )

        test_pb2_grpc.add_TestServiceServicer_to_server(
            TestServiceServicer(), self._async_server
        )
        self._adhoc_handlers = _common.AdhocGenericHandler()
        self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,))

        port = self._async_server.add_insecure_port("[::]:0")
        address = "localhost:%d" % port
        await self._async_server.start()

        # Create async stub
        self._async_channel = aio.insecure_channel(
            address, options=_unique_options()
        )
        self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel)

        # Create sync stub
        self._sync_channel = grpc.insecure_channel(
            address, options=_unique_options()
        )
        self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel)

    async def tearDown(self):
        """Close both channels, then stop the AsyncIO server."""
        self._sync_channel.close()
        await self._async_channel.close()
        await self._async_server.stop(None)

    async def _run_in_another_thread(self, func: Callable[[], None]):
        """Run `func` in a fresh daemon thread and wait for it to finish.

        Args:
          func: A blocking, zero-argument callable.

        Raises:
          Exception: Whatever `func` raised in the worker thread, re-raised
            here so that failed assertions fail the test instead of being
            lost in the thread.
        """
        work_done = asyncio.Event()
        failures = []

        def thread_work():
            try:
                func()
            except Exception as error:  # pylint: disable=broad-except
                failures.append(error)
            finally:
                # Always wake the waiter; otherwise a failing `func` would
                # leave the coroutine blocked on the event forever.
                self.loop.call_soon_threadsafe(work_done.set)

        thread = threading.Thread(target=thread_work, daemon=True)
        thread.start()
        await work_done.wait()
        thread.join()

        if failures:
            raise failures[0]

    async def test_unary_unary(self):
        """Unary-unary RPC through the async stub, then the sync stub."""
        # Calling async API in this thread
        await self._async_stub.UnaryCall(
            messages_pb2.SimpleRequest(), timeout=test_constants.LONG_TIMEOUT
        )

        # Calling sync API in a different thread
        def sync_work() -> None:
            response, call = self._sync_stub.UnaryCall.with_call(
                messages_pb2.SimpleRequest(),
                timeout=test_constants.LONG_TIMEOUT,
            )
            self.assertIsInstance(response, messages_pb2.SimpleResponse)
            self.assertEqual(grpc.StatusCode.OK, call.code())

        await self._run_in_another_thread(sync_work)

    async def test_unary_stream(self):
        """Unary-stream RPC through the async stub, then the sync stub."""
        request = messages_pb2.StreamingOutputCallRequest()
        for _ in range(_NUM_STREAM_RESPONSES):
            request.response_parameters.append(
                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
            )

        # Calling async API in this thread
        call = self._async_stub.StreamingOutputCall(request)

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.read()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.StreamingOutputCall(request)
            for response in response_iterator:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_stream_unary(self):
        """Stream-unary RPC through the async stub, then the sync stub."""
        payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE)
        request = messages_pb2.StreamingInputCallRequest(payload=payload)

        # Calling async API in this thread
        async def gen():
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        response = await self._async_stub.StreamingInputCall(gen())
        self.assertEqual(
            _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
            response.aggregated_payload_size,
        )

        # Calling sync API in a different thread
        def sync_work() -> None:
            response = self._sync_stub.StreamingInputCall(
                iter([request] * _NUM_STREAM_RESPONSES)
            )
            self.assertEqual(
                _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
                response.aggregated_payload_size,
            )

        await self._run_in_another_thread(sync_work)

    async def test_stream_stream(self):
        """Stream-stream RPC through the async stub, then the sync stub."""
        request = messages_pb2.StreamingOutputCallRequest()
        request.response_parameters.append(
            messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
        )

        # Calling async API in this thread
        call = self._async_stub.FullDuplexCall()

        for _ in range(_NUM_STREAM_RESPONSES):
            await call.write(request)
            response = await call.read()
            self.assertEqual(
                _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
            )

        await call.done_writing()
        self.assertEqual(grpc.StatusCode.OK, await call.code())

        # Calling sync API in a different thread
        def sync_work() -> None:
            response_iterator = self._sync_stub.FullDuplexCall(iter([request]))
            for response in response_iterator:
                self.assertEqual(
                    _RESPONSE_PAYLOAD_SIZE, len(response.payload.body)
                )
            self.assertEqual(grpc.StatusCode.OK, response_iterator.code())

        await self._run_in_another_thread(sync_work)

    async def test_server(self):
        """Run a sync server created in the event loop thread."""

        class GenericHandlers(grpc.GenericRpcHandler):
            def service(self, handler_call_details):
                return grpc.unary_unary_rpc_method_handler(lambda x, _: x)

        # It's fine to instantiate server object in the event loop thread.
        # The server will spawn its own serving thread.
        server = grpc.server(
            ThreadPoolExecutor(), handlers=(GenericHandlers(),)
        )
        port = server.add_insecure_port("localhost:0")
        server.start()

        def sync_work() -> None:
            for _ in range(100):
                with grpc.insecure_channel("localhost:%d" % port) as channel:
                    response = channel.unary_unary("/test/test")(b"\x07\x08")
                    self.assertEqual(response, b"\x07\x08")

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            # Do not leave the sync server running after the test, whether
            # the threaded work passed or failed.
            server.stop(None)

    async def test_many_loop(self):
        """Use the AsyncIO API from a second event loop in another thread."""
        address, server = await start_test_server()

        # Run another loop in another thread
        def sync_work():
            async def async_work():
                # Create async stub
                async_channel = aio.insecure_channel(
                    address, options=_unique_options()
                )
                async_stub = test_pb2_grpc.TestServiceStub(async_channel)

                call = async_stub.UnaryCall(messages_pb2.SimpleRequest())
                response = await call
                self.assertIsInstance(response, messages_pb2.SimpleResponse)
                self.assertEqual(grpc.StatusCode.OK, await call.code())

            loop = asyncio.new_event_loop()
            loop.run_until_complete(async_work())

        try:
            await self._run_in_another_thread(sync_work)
        finally:
            # Stop the extra server even when the threaded work fails.
            await server.stop(None)

    async def test_sync_unary_unary_success(self):
        """A sync unary-unary handler echoes the request to an async client."""

        @grpc.unary_unary_rpc_method_handler
        def echo_unary_unary(request: bytes, unused_context):
            return request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_unary)
        response = await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
            _REQUEST
        )
        self.assertEqual(_REQUEST, response)

    async def test_sync_unary_unary_metadata(self):
        """Initial metadata sent by a sync handler reaches the async client."""
        metadata = (("unique", "key-42"),)

        @grpc.unary_unary_rpc_method_handler
        def metadata_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.send_initial_metadata(metadata)
            return request

        self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary)
        call = self._async_channel.unary_unary(_common.ADHOC_METHOD)(_REQUEST)
        self.assertTrue(
            _common.seen_metadata(
                aio.Metadata(*metadata), await call.initial_metadata()
            )
        )

    async def test_sync_unary_unary_abort(self):
        """`context.abort` in a sync handler surfaces as an AioRpcError."""

        @grpc.unary_unary_rpc_method_handler
        def abort_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.abort(grpc.StatusCode.INTERNAL, "Test")

        self._adhoc_handlers.set_adhoc_handler(abort_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
                _REQUEST
            )
        self.assertEqual(
            grpc.StatusCode.INTERNAL, exception_context.exception.code()
        )

    async def test_sync_unary_unary_set_code(self):
        """`context.set_code` in a sync handler surfaces as an AioRpcError."""

        @grpc.unary_unary_rpc_method_handler
        def set_code_unary_unary(request: bytes, context: grpc.ServicerContext):
            context.set_code(grpc.StatusCode.INTERNAL)

        self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.unary_unary(_common.ADHOC_METHOD)(
                _REQUEST
            )
        self.assertEqual(
            grpc.StatusCode.INTERNAL, exception_context.exception.code()
        )

    async def test_sync_unary_stream_success(self):
        """A sync unary-stream handler streams echoes to an async client."""

        @grpc.unary_stream_rpc_method_handler
        def echo_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_unary_stream)
        call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
        async for response in call:
            self.assertEqual(_REQUEST, response)

    async def test_sync_unary_stream_error(self):
        """An exception in a sync unary-stream handler yields UNKNOWN."""

        @grpc.unary_stream_rpc_method_handler
        def error_unary_stream(request: bytes, unused_context):
            for _ in range(_NUM_STREAM_RESPONSES):
                yield request
            raise RuntimeError("Test")

        self._adhoc_handlers.set_adhoc_handler(error_unary_stream)
        call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    async def test_sync_stream_unary_success(self):
        """A sync stream-unary handler consumes an async client's stream."""

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(
            request_iterator: Iterable[bytes], unused_context
        ):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            return _REQUEST

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        response = await self._async_channel.stream_unary(_common.ADHOC_METHOD)(
            request_iterator
        )
        self.assertEqual(_REQUEST, response)

    async def test_sync_stream_unary_error(self):
        """An exception in a sync stream-unary handler yields UNKNOWN."""

        @grpc.stream_unary_rpc_method_handler
        def echo_stream_unary(
            request_iterator: Iterable[bytes], unused_context
        ):
            self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES)
            raise RuntimeError("Test")

        self._adhoc_handlers.set_adhoc_handler(echo_stream_unary)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await self._async_channel.stream_unary(_common.ADHOC_METHOD)(
                request_iterator
            )
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    async def test_sync_stream_stream_success(self):
        """A sync stream-stream handler echoes an async client's stream."""

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(
            request_iterator: Iterable[bytes], unused_context
        ):
            for request in request_iterator:
                yield request

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_common.ADHOC_METHOD)(
            request_iterator
        )
        async for response in call:
            self.assertEqual(_REQUEST, response)

    async def test_sync_stream_stream_error(self):
        """An exception in a sync stream-stream handler yields UNKNOWN."""

        @grpc.stream_stream_rpc_method_handler
        def echo_stream_stream(
            request_iterator: Iterable[bytes], unused_context
        ):
            for request in request_iterator:
                yield request
            raise RuntimeError("test")

        self._adhoc_handlers.set_adhoc_handler(echo_stream_stream)
        request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES)
        call = self._async_channel.stream_stream(_common.ADHOC_METHOD)(
            request_iterator
        )
        with self.assertRaises(aio.AioRpcError) as exception_context:
            async for response in call:
                self.assertEqual(_REQUEST, response)
        self.assertEqual(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)