# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the Core channel arguments."""

import asyncio
import errno
import logging
import platform
import random
import unittest

import grpc
from grpc.experimental import aio

from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework import common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server

_RANDOM_SEED = 42

_ENABLE_REUSE_PORT = "SO_REUSEPORT enabled"
_DISABLE_REUSE_PORT = "SO_REUSEPORT disabled"
_SOCKET_OPT_SO_REUSEPORT = "grpc.so_reuseport"
# Pairs of (expected behavior label, server options that should produce it).
_OPTIONS = (
    (_ENABLE_REUSE_PORT, ((_SOCKET_OPT_SO_REUSEPORT, 1),)),
    (_DISABLE_REUSE_PORT, ((_SOCKET_OPT_SO_REUSEPORT, 0),)),
)

# Number of servers created concurrently by the SO_REUSEPORT test.
_NUM_SERVER_CREATED = 5

_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH = "grpc.max_receive_message_length"
_MAX_MESSAGE_LENGTH = 1024

# errno values that mean "the address is already taken" when a second socket
# tries to bind to the server's port.
_ADDRESS_TOKEN_ERRNO = errno.EADDRINUSE, errno.ENOSR


class _TestPointerWrapper(object):
    """Channel-arg value that is neither str/bytes nor int but defines __int__."""

    def __int__(self):
        return 123456


# Option key/value combinations (str and bytes keys; bytes, str, int and
# int-convertible values) that the client and server constructors must accept
# without crashing.
_TEST_CHANNEL_ARGS = (
    ("arg1", b"bytes_val"),
    ("arg2", "str_val"),
    ("arg3", 1),
    (b"arg4", "str_val"),
    ("arg6", _TestPointerWrapper()),
)

# Malformed ``options`` values that ``aio.insecure_channel`` is expected to
# reject with ValueError or TypeError.
_INVALID_TEST_CHANNEL_ARGS = [
    {"foo": "bar"},
    (("key",),),
    "str",
]


async def test_if_reuse_port_enabled(server: aio.Server) -> bool:
    """Checks whether the server's port can be bound by a second socket.

    Adds an insecure port on localhost to ``server``, starts it, and then tries
    to bind another, non-listening socket to the same port. The caller owns
    ``server`` and must stop it.

    Returns:
        True if the second bind succeeded. False if it failed with one of
        the errno values in ``_ADDRESS_TOKEN_ERRNO``.

    Raises:
        OSError: Any other socket error. It is logged and then re-raised.
    """
    port = server.add_insecure_port("localhost:0")
    await server.start()

    try:
        with common.bound_socket(
            bind_address="localhost",
            port=port,
            listen=False,
        ) as (unused_host, bound_port):
            assert bound_port == port
    except OSError as e:
        if e.errno in _ADDRESS_TOKEN_ERRNO:
            return False
        else:
            logging.exception(e)
            raise
    else:
        return True


# This is a helper, not a test case. Without this flag, pytest would collect it
# because of its ``test_`` prefix and fail on the missing ``server`` fixture.
test_if_reuse_port_enabled.__test__ = False


class TestChannelArgument(AioTestBase):
    """Exercises how aio channels and servers handle channel arguments."""

    async def setUp(self):
        # Seed the global RNG used below to pick server options.
        random.seed(_RANDOM_SEED)

    @unittest.skipIf(
        platform.system() == "Windows",
        "SO_REUSEPORT only available in Linux-like OS.",
    )
    @unittest.skipIf(
        "aarch64" in platform.machine(),
        "SO_REUSEPORT needs to be enabled in Core's port.h.",
    )
    async def test_server_so_reuse_port_is_set_properly(self):
        async def test_body():
            fact, options = random.choice(_OPTIONS)
            server = aio.server(options=options)
            try:
                result = await test_if_reuse_port_enabled(server)
                if fact == _ENABLE_REUSE_PORT and not result:
                    self.fail(
                        "Enabled reuse port in options, but not observed in"
                        " socket"
                    )
                elif fact == _DISABLE_REUSE_PORT and result:
                    self.fail(
                        "Disabled reuse port in options, but observed in socket"
                    )
            finally:
                await server.stop(None)

        # Creating a lot of servers concurrently
        await asyncio.gather(*(test_body() for _ in range(_NUM_SERVER_CREATED)))

    async def test_client(self):
        # Do not segfault, or raise exception!
        channel = aio.insecure_channel("[::]:0", options=_TEST_CHANNEL_ARGS)
        await channel.close()

    async def test_server(self):
        # Do not segfault, or raise exception!
        server = aio.server(options=_TEST_CHANNEL_ARGS)
        await server.stop(None)

    async def test_invalid_client_args(self):
        for invalid_arg in _INVALID_TEST_CHANNEL_ARGS:
            self.assertRaises(
                (ValueError, TypeError),
                aio.insecure_channel,
                "[::]:0",
                options=invalid_arg,
            )

    async def test_max_message_length_applied(self):
        address, server = await start_test_server()
        try:
            async with aio.insecure_channel(
                address,
                options=(
                    (_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, _MAX_MESSAGE_LENGTH),
                ),
            ) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                request = messages_pb2.StreamingOutputCallRequest()
                # The first response (half the client's receive limit) fits.
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH // 2,
                    )
                )
                # The second response (twice the limit) must be rejected.
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH * 2,
                    )
                )

                call = stub.StreamingOutputCall(request)

                response = await call.read()
                self.assertEqual(
                    _MAX_MESSAGE_LENGTH // 2, len(response.payload.body)
                )

                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await call.read()
                rpc_error = exception_context.exception
                self.assertEqual(
                    grpc.StatusCode.RESOURCE_EXHAUSTED, rpc_error.code()
                )
                self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details())

                self.assertEqual(
                    grpc.StatusCode.RESOURCE_EXHAUSTED, await call.code()
                )
        finally:
            # Stop the server even when an assertion or RPC above fails, so a
            # failing run does not leak a running server.
            await server.stop(None)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)