# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the compression mechanism."""

import asyncio
import logging
import platform
import random
import unittest

import grpc
from grpc.experimental import aio

from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common

# Compression-related channel arguments. The integer values follow gRPC
# core's ``grpc_compression_algorithm`` enum (0 = identity, 1 = deflate,
# 2 = gzip); in the "enabled algorithms" bitset, bit ``i`` enables
# algorithm ``i``.
_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2)
# 0b011: identity + deflate enabled, gzip disabled.
_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset',
                                   3)
# 0b101: identity + gzip enabled, deflate disabled.
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
    'grpc.compression_enabled_algorithms_bitset', 5)

_TEST_UNARY_UNARY = '/test/TestUnaryUnary'
_TEST_SET_COMPRESSION = '/test/TestSetCompression'
_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary'
_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream'

_REQUEST = b'\x01' * 100
_RESPONSE = b'\x02' * 100


async def _test_unary_unary(unused_request, unused_context):
    """Unary-unary handler that ignores its input and returns ``_RESPONSE``."""
    return _RESPONSE


async def _test_set_compression(unused_request_iterator, context):
    """Stream-stream handler that selects Deflate, then checks it is frozen.

    Reads one request, switches the call to Deflate and writes a response.
    A second ``set_compression`` after that write must raise ``RuntimeError``;
    otherwise a ``ValueError`` is raised so the RPC (and the test) fails.
    """
    assert _REQUEST == await context.read()
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)
    try:
        context.set_compression(grpc.Compression.Deflate)
    except RuntimeError:
        # NOTE(lidiz) Testing if the servicer context raises exception when
        # the set_compression method is called after initial_metadata sent.
        # After the initial_metadata sent, the server-side has no control over
        # which compression algorithm it should use.
        pass
    else:
        raise ValueError(
            'Expecting exceptions if set_compression is not effective')


async def _test_disable_compression_unary(request, context):
    """Unary-unary handler that selects Deflate but skips it for the reply."""
    assert _REQUEST == request
    context.set_compression(grpc.Compression.Deflate)
    context.disable_next_message_compression()
    return _RESPONSE


async def _test_disable_compression_stream(unused_request_iterator, context):
    """Stream-stream handler that disables compression for the 2nd message.

    Writes three responses under Deflate; only the second one is sent with
    ``disable_next_message_compression`` in effect.
    """
    assert _REQUEST == await context.read()
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)
    context.disable_next_message_compression()
    await context.write(_RESPONSE)
    await context.write(_RESPONSE)


# Maps fully-qualified method names to their RPC method handlers.
_ROUTING_TABLE = {
    _TEST_UNARY_UNARY:
        grpc.unary_unary_rpc_method_handler(_test_unary_unary),
    _TEST_SET_COMPRESSION:
        grpc.stream_stream_rpc_method_handler(_test_set_compression),
    _TEST_DISABLE_COMPRESSION_UNARY:
        grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary),
    _TEST_DISABLE_COMPRESSION_STREAM:
        grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream),
}


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes incoming calls through ``_ROUTING_TABLE``."""

    def service(self, handler_call_details):
        # Unknown methods yield None, letting gRPC report UNIMPLEMENTED.
        return _ROUTING_TABLE.get(handler_call_details.method)


async def _start_test_server(options=None, compression=None):
    """Start an aio server serving ``_ROUTING_TABLE`` on an ephemeral port.

    Args:
      options: Optional channel arguments applied to the server.
      compression: Optional default ``grpc.Compression`` for the server;
        ``None`` keeps the library default.

    Returns:
      A ``(address, server)`` tuple. The caller is responsible for stopping
      the server.
    """
    server = aio.server(options=options, compression=compression)
    port = server.add_insecure_port('[::]:0')
    server.add_generic_rpc_handlers((_GenericHandler(),))
    await server.start()
    return f'localhost:{port}', server


class TestCompression(AioTestBase):
    """End-to-end compression tests against a server that rejects GZIP."""

    async def setUp(self):
        server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
        self._address, self._server = await _start_test_server(server_options)
        self._channel = aio.insecure_channel(self._address)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_channel_level_compression_baned_compression(self):
        # GZIP is disabled, this call should fail
        async with aio.insecure_channel(
                self._address, compression=grpc.Compression.Gzip) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call
            rpc_error = exception_context.exception
            self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_channel_level_compression_allowed_compression(self):
        # Deflate is allowed, this call should succeed
        async with aio.insecure_channel(
                self._address, compression=grpc.Compression.Deflate) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_client_call_level_compression_baned_compression(self):
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # GZIP is disabled, this call should fail
        call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_client_call_level_compression_allowed_compression(self):
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # Deflate is allowed, this call should succeed
        call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_call_level_compression(self):
        multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_unary(self):
        multicallable = self._channel.unary_unary(
            _TEST_DISABLE_COMPRESSION_UNARY)
        call = multicallable(_REQUEST)
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_stream(self):
        multicallable = self._channel.stream_stream(
            _TEST_DISABLE_COMPRESSION_STREAM)
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_default_compression_algorithm(self):
        # Dedicated server whose default compression is Deflate. It is
        # stopped in ``finally`` so a failing assertion cannot leak it.
        address, server = await _start_test_server(
            compression=grpc.Compression.Deflate)
        try:
            async with aio.insecure_channel(address) as channel:
                multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
                call = multicallable(_REQUEST)
                self.assertEqual(_RESPONSE, await call)
                self.assertEqual(grpc.StatusCode.OK, await call.code())
        finally:
            await server.stop(None)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)