• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior around the compression mechanism."""
15
16import asyncio
17import logging
18import platform
19import random
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from tests_aio.unit import _common
26from tests_aio.unit._test_base import AioTestBase
27
# Compression algorithm indices, consistent with the names below:
# 0 = identity (no compression), 1 = deflate, 2 = gzip.
# Make gzip the default algorithm for a channel/server.
_GZIP_CHANNEL_ARGUMENT = ("grpc.default_compression_algorithm", 2)
# Enabled-algorithm bitset with only identity and deflate set (gzip bit off).
_GZIP_DISABLED_CHANNEL_ARGUMENT = (
    "grpc.compression_enabled_algorithms_bitset",
    (1 << 0) | (1 << 1),
)
# Enabled-algorithm bitset with only identity and gzip set (deflate bit off).
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
    "grpc.compression_enabled_algorithms_bitset",
    (1 << 0) | (1 << 2),
)

# Fully-qualified method paths served by the test server.
_TEST_UNARY_UNARY = "/test/TestUnaryUnary"
_TEST_SET_COMPRESSION = "/test/TestSetCompression"
_TEST_DISABLE_COMPRESSION_UNARY = "/test/TestDisableCompressionUnary"
_TEST_DISABLE_COMPRESSION_STREAM = "/test/TestDisableCompressionStream"

# Fixed 100-byte payloads; distinct fill bytes make request/response
# mix-ups visible in assertions.
_REQUEST = bytes([0x01]) * 100
_RESPONSE = bytes([0x02]) * 100
45
46
async def _test_unary_unary(unused_request, unused_context):
    """Unary-unary handler: ignores its input and replies with ``_RESPONSE``."""
    reply = _RESPONSE
    return reply
49
50
async def _test_set_compression(unused_request_iterator, context):
    """Stream-stream handler checking that a late set_compression() fails.

    Reads one request and selects deflate for the response. It then writes one
    response and confirms that calling set_compression() again raises
    RuntimeError. If no error is raised, the handler fails the RPC by
    raising ValueError.
    """
    assert _REQUEST == await context.read()
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)

    late_change_rejected = False
    try:
        context.set_compression(grpc.Compression.Deflate)
    except RuntimeError:
        # Once initial metadata has been sent, the server side no longer
        # controls which compression algorithm is used, so the servicer
        # context is expected to refuse the change.
        late_change_rejected = True

    if not late_change_rejected:
        raise ValueError(
            "Expecting exceptions if set_compression is not effective"
        )
67
68
async def _test_disable_compression_unary(request, context):
    """Unary-unary handler that requests deflate, then opts the reply out.

    The request must equal ``_REQUEST``. The single response is returned
    after ``disable_next_message_compression()`` has been called.
    """
    assert _REQUEST == request
    context.set_compression(grpc.Compression.Deflate)
    # Applies to the next (and only) response message of this unary call.
    context.disable_next_message_compression()
    reply = _RESPONSE
    return reply
74
75
async def _test_disable_compression_stream(unused_request_iterator, context):
    """Stream-stream handler mixing compressed and uncompressed responses.

    Reads one request, selects deflate, and writes three responses. Between
    the first and second writes it calls
    ``disable_next_message_compression()``.
    """
    assert _REQUEST == await context.read()
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)
    # Per the API name, only the message written next skips compression.
    context.disable_next_message_compression()
    for _ in range(2):
        await context.write(_RESPONSE)
83
84
# Method path -> RPC method handler; consulted by _GenericHandler.service.
_ROUTING_TABLE = dict(
    [
        (
            _TEST_UNARY_UNARY,
            grpc.unary_unary_rpc_method_handler(_test_unary_unary),
        ),
        (
            _TEST_SET_COMPRESSION,
            grpc.stream_stream_rpc_method_handler(_test_set_compression),
        ),
        (
            _TEST_DISABLE_COMPRESSION_UNARY,
            grpc.unary_unary_rpc_method_handler(
                _test_disable_compression_unary
            ),
        ),
        (
            _TEST_DISABLE_COMPRESSION_STREAM,
            grpc.stream_stream_rpc_method_handler(
                _test_disable_compression_stream
            ),
        ),
    ]
)
97
98
class _GenericHandler(grpc.GenericRpcHandler):
    """Dispatches incoming calls by method path using ``_ROUTING_TABLE``."""

    def service(self, handler_call_details):
        # A path missing from the table yields None, i.e. this handler
        # declines the RPC.
        method_path = handler_call_details.method
        return _ROUTING_TABLE.get(method_path)
102
103
async def _start_test_server(options=None):
    """Start an insecure aio server serving the test routes.

    Args:
      options: Optional channel arguments forwarded to ``aio.server``.

    Returns:
      A ``(address, server)`` tuple, where ``address`` is
      ``"localhost:<port>"`` for the ephemeral port the server bound.
    """
    test_server = aio.server(options=options)
    # Port 0 asks for any free port; the chosen port number is returned.
    bound_port = test_server.add_insecure_port("[::]:0")
    test_server.add_generic_rpc_handlers((_GenericHandler(),))
    await test_server.start()
    return f"localhost:{bound_port}", test_server
110
111
class TestCompression(AioTestBase):
    """Exercises channel-, call- and server-level compression settings.

    The shared server started in ``setUp`` is configured with
    ``_GZIP_DISABLED_CHANNEL_ARGUMENT``. Gzip-compressed calls to it are
    expected to fail with UNIMPLEMENTED, while deflate-compressed calls
    succeed.
    """

    async def setUp(self):
        server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
        self._address, self._server = await _start_test_server(server_options)
        self._channel = aio.insecure_channel(self._address)

    async def tearDown(self):
        # Stop the server even if closing the channel raises, so one failure
        # does not leave a listening server behind for later tests.
        try:
            await self._channel.close()
        finally:
            await self._server.stop(None)

    # NOTE: "baned" in the test names below is a typo for "banned"; the names
    # are left unchanged so existing test selectors keep matching.
    async def test_channel_level_compression_baned_compression(self):
        # GZIP is disabled, this call should fail
        async with aio.insecure_channel(
            self._address, compression=grpc.Compression.Gzip
        ) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call
            rpc_error = exception_context.exception
            self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_channel_level_compression_allowed_compression(self):
        # Deflate is allowed, this call should succeed
        async with aio.insecure_channel(
            self._address, compression=grpc.Compression.Deflate
        ) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_client_call_level_compression_baned_compression(self):
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # GZIP is disabled, this call should fail
        call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_client_call_level_compression_allowed_compression(self):
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # Deflate is allowed, this call should succeed
        call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_call_level_compression(self):
        # The handler itself asserts that a late set_compression() raises;
        # an OK status means that check passed on the server side.
        multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_unary(self):
        multicallable = self._channel.unary_unary(
            _TEST_DISABLE_COMPRESSION_UNARY
        )
        call = multicallable(_REQUEST)
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_stream(self):
        # The handler writes three responses; all must arrive intact.
        multicallable = self._channel.stream_stream(
            _TEST_DISABLE_COMPRESSION_STREAM
        )
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_default_compression_algorithm(self):
        # Uses a dedicated server whose server-wide default compression is
        # deflate, instead of the shared gzip-disabled server from setUp.
        server = aio.server(compression=grpc.Compression.Deflate)
        port = server.add_insecure_port("[::]:0")
        server.add_generic_rpc_handlers((_GenericHandler(),))
        await server.start()
        try:
            async with aio.insecure_channel(f"localhost:{port}") as channel:
                multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
                call = multicallable(_REQUEST)
                self.assertEqual(_RESPONSE, await call)
                self.assertEqual(grpc.StatusCode.OK, await call.code())
        finally:
            # tearDown only knows about the shared server, so this ad-hoc
            # one must be stopped here even when an assertion fails.
            await server.stop(None)
201
202
if __name__ == "__main__":
    # When run directly: emit DEBUG-level logs and per-test verbose output.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
206