• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior around the Core channel arguments."""
15
16import asyncio
17import errno
18import logging
19import platform
20import random
21import unittest
22
23import grpc
24from grpc.experimental import aio
25
26from src.proto.grpc.testing import messages_pb2
27from src.proto.grpc.testing import test_pb2_grpc
28from tests.unit.framework import common
29from tests_aio.unit._test_base import AioTestBase
30from tests_aio.unit._test_server import start_test_server
31
# Seed for random.choice so the enable/disable mix of servers is reproducible.
_RANDOM_SEED = 42

# Human-readable labels for each SO_REUSEPORT setting, each paired in _OPTIONS
# with the server option tuple that requests that setting.
_ENABLE_REUSE_PORT = "SO_REUSEPORT enabled"
_DISABLE_REUSE_PORT = "SO_REUSEPORT disabled"
_SOCKET_OPT_SO_REUSEPORT = "grpc.so_reuseport"
_OPTIONS = tuple(
    (label, ((_SOCKET_OPT_SO_REUSEPORT, flag),))
    for label, flag in ((_ENABLE_REUSE_PORT, 1), (_DISABLE_REUSE_PORT, 0))
)

# Number of servers the SO_REUSEPORT test creates concurrently.
_NUM_SERVER_CREATED = 5

_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH = "grpc.max_receive_message_length"
_MAX_MESSAGE_LENGTH = 1024

# errno values treated as "port already taken" when probing a bound port.
# NOTE(review): "TOKEN" presumably means "TAKEN"; why ENOSR belongs here is not
# evident from this file -- confirm.
_ADDRESS_TOKEN_ERRNO = (errno.EADDRINUSE, errno.ENOSR)
48
49
class _TestPointerWrapper:
    """Test double whose ``int()`` conversion yields a fixed value.

    Used as a channel-argument value that is neither bytes, str, nor int,
    but converts to an integer through ``__int__``.
    NOTE(review): presumably this exercises Core's pointer-typed channel
    arguments, as the class name suggests -- confirm.
    """

    def __int__(self) -> int:
        return 123456
53
54
# Well-formed options mixing key and value types; creating a channel or
# server with these must neither crash nor raise.
_TEST_CHANNEL_ARGS = (
    # str key, bytes value
    ("arg1", b"bytes_val"),
    # str key, str value
    ("arg2", "str_val"),
    # str key, int value
    ("arg3", 1),
    # bytes key, str value
    (b"arg4", "str_val"),
    # str key, object converted through __int__
    ("arg6", _TestPointerWrapper()),
)

# Malformed option collections; channel creation is expected to reject each
# one with ValueError or TypeError.
_INVALID_TEST_CHANNEL_ARGS = [
    # a mapping instead of a sequence of (key, value) pairs
    {"foo": "bar"},
    # a pair that is missing its value
    (("key",),),
    # a bare string
    "str",
]
68
69
async def test_if_reuse_port_enabled(server: aio.Server) -> bool:
    """Start ``server`` on an ephemeral localhost port and probe it for reuse.

    After the server starts, a second socket is bound (without listening) to
    the same port through ``common.bound_socket``.

    Args:
      server: A not-yet-started aio server. It is left running; the caller
        is responsible for stopping it.

    Returns:
      True if the second bind succeeds; False if it fails with one of the
      errno values in ``_ADDRESS_TOKEN_ERRNO``.

    Raises:
      OSError: Any other bind failure, after it has been logged.
      AssertionError: If the probe socket ends up on a different port.

    NOTE(review): presumably ``bound_socket`` requests SO_REUSEPORT on the
    probe socket, so a successful bind means the server socket also allows
    port reuse -- confirm in tests.unit.framework.common.
    """
    port = server.add_insecure_port("localhost:0")
    await server.start()

    try:
        with common.bound_socket(
            bind_address="localhost",
            port=port,
            listen=False,
        ) as (_, bound_port):
            assert bound_port == port
    except OSError as e:
        # Only "port is taken" errors are an expected outcome of the probe.
        if e.errno not in _ADDRESS_TOKEN_ERRNO:
            logging.exception(e)
            raise
        return False
    return True
89
90
class TestChannelArgument(AioTestBase):
    """Exercises how aio channels and servers handle channel-argument options."""

    async def setUp(self):
        # Fixed seed so the option mix picked by random.choice is reproducible.
        random.seed(_RANDOM_SEED)

    @unittest.skipIf(
        platform.system() == "Windows",
        "SO_REUSEPORT only available in Linux-like OS.",
    )
    @unittest.skipIf(
        "aarch64" in platform.machine(),
        "SO_REUSEPORT needs to be enabled in Core's port.h.",
    )
    async def test_server_so_reuse_port_is_set_properly(self):
        """Servers observe the ``grpc.so_reuseport`` value they were given."""

        async def test_body():
            # Each server randomly gets SO_REUSEPORT either enabled or disabled.
            fact, options = random.choice(_OPTIONS)
            server = aio.server(options=options)
            try:
                result = await test_if_reuse_port_enabled(server)
                if fact == _ENABLE_REUSE_PORT and not result:
                    self.fail(
                        "Enabled reuse port in options, but not observed in"
                        " socket"
                    )
                elif fact == _DISABLE_REUSE_PORT and result:
                    self.fail(
                        "Disabled reuse port in options, but observed in socket"
                    )
            finally:
                await server.stop(None)

        # Creating a lot of servers concurrently
        await asyncio.gather(*(test_body() for _ in range(_NUM_SERVER_CREATED)))

    async def test_client(self):
        """A channel built with mixed-type options can be created and closed."""
        # Do not segfault, or raise exception!
        channel = aio.insecure_channel("[::]:0", options=_TEST_CHANNEL_ARGS)
        await channel.close()

    async def test_server(self):
        """A server built with mixed-type options can be created and stopped."""
        # Do not segfault, or raise exception!
        server = aio.server(options=_TEST_CHANNEL_ARGS)
        await server.stop(None)

    async def test_invalid_client_args(self):
        """Malformed option collections are rejected at channel creation."""
        for invalid_arg in _INVALID_TEST_CHANNEL_ARGS:
            # subTest reports every offending value instead of stopping at
            # the first one that is not rejected.
            with self.subTest(invalid_arg=invalid_arg):
                self.assertRaises(
                    (ValueError, TypeError),
                    aio.insecure_channel,
                    "[::]:0",
                    options=invalid_arg,
                )

    async def test_max_message_length_applied(self):
        """A client-side receive limit fails oversized streamed responses."""
        address, server = await start_test_server()
        try:
            async with aio.insecure_channel(
                address,
                options=(
                    (_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, _MAX_MESSAGE_LENGTH),
                ),
            ) as channel:
                stub = test_pb2_grpc.TestServiceStub(channel)

                # One streaming request asking for two responses.
                request = messages_pb2.StreamingOutputCallRequest()
                # The first response fits within the receive limit...
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH // 2,
                    )
                )
                # ...the second exceeds it, so reading it should fail.
                request.response_parameters.append(
                    messages_pb2.ResponseParameters(
                        size=_MAX_MESSAGE_LENGTH * 2,
                    )
                )

                call = stub.StreamingOutputCall(request)

                response = await call.read()
                self.assertEqual(
                    _MAX_MESSAGE_LENGTH // 2, len(response.payload.body)
                )

                with self.assertRaises(aio.AioRpcError) as exception_context:
                    await call.read()
                rpc_error = exception_context.exception
                self.assertEqual(
                    grpc.StatusCode.RESOURCE_EXHAUSTED, rpc_error.code()
                )
                # The error details are expected to mention the limit.
                self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details())

                self.assertEqual(
                    grpc.StatusCode.RESOURCE_EXHAUSTED, await call.code()
                )
        finally:
            # Stop the server even when an assertion above fails, so a
            # failing run does not leak it into later tests.
            await server.stop(None)
188
189
if __name__ == "__main__":
    # When run directly as a script: emit debug-level logs and report each test.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
193