# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the compression mechanism."""

import asyncio
import logging
import platform
import random
import unittest

import grpc
from grpc.experimental import aio

from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase

# Channel arguments controlling compression. The "enabled algorithms" value
# is a bitset in which bit i enables algorithm i of gRPC core's compression
# enum (0 = identity, 1 = deflate, 2 = gzip): 3 (0b011) leaves gzip off and
# 5 (0b101) leaves deflate off.
_GZIP_CHANNEL_ARGUMENT = ("grpc.default_compression_algorithm", 2)
_GZIP_DISABLED_CHANNEL_ARGUMENT = (
    "grpc.compression_enabled_algorithms_bitset",
    3,
)
_DEFLATE_DISABLED_CHANNEL_ARGUMENT = (
    "grpc.compression_enabled_algorithms_bitset",
    5,
)

# Fully-qualified method names served by _GenericHandler.
_TEST_UNARY_UNARY = "/test/TestUnaryUnary"
_TEST_SET_COMPRESSION = "/test/TestSetCompression"
_TEST_DISABLE_COMPRESSION_UNARY = "/test/TestDisableCompressionUnary"
_TEST_DISABLE_COMPRESSION_STREAM = "/test/TestDisableCompressionStream"

_REQUEST = b"\x01" * 100
_RESPONSE = b"\x02" * 100


async def _test_unary_unary(unused_request, unused_context):
    """Unary-unary handler that always answers with _RESPONSE."""
    return _RESPONSE


async def _test_set_compression(unused_request_iterator, context):
    """Stream-stream handler checking when set_compression may be called.

    Reads one request, selects Deflate, writes one response, and then
    verifies that a second set_compression call raises RuntimeError.

    Raises:
        ValueError: If the second set_compression call does not raise.
    """
    # Read outside of the assert so the message is still consumed when
    # assertions are stripped (python -O).
    request = await context.read()
    assert _REQUEST == request
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)
    try:
        context.set_compression(grpc.Compression.Deflate)
    except RuntimeError:
        # NOTE(lidiz) Testing if the servicer context raises exception when
        # the set_compression method is called after initial_metadata sent.
        # After the initial_metadata sent, the server-side has no control over
        # which compression algorithm it should use.
        pass
    else:
        raise ValueError(
            "Expecting exceptions if set_compression is not effective"
        )


async def _test_disable_compression_unary(request, context):
    """Unary-unary handler that disables compression for its only response."""
    assert _REQUEST == request
    context.set_compression(grpc.Compression.Deflate)
    context.disable_next_message_compression()
    return _RESPONSE


async def _test_disable_compression_stream(unused_request_iterator, context):
    """Stream-stream handler writing three responses.

    disable_next_message_compression is called between the first and the
    second write; the third write follows without any further call.
    """
    # Read outside of the assert so the message is still consumed when
    # assertions are stripped (python -O).
    request = await context.read()
    assert _REQUEST == request
    context.set_compression(grpc.Compression.Deflate)
    await context.write(_RESPONSE)
    context.disable_next_message_compression()
    await context.write(_RESPONSE)
    await context.write(_RESPONSE)


# Maps each served method name to its RPC method handler.
_ROUTING_TABLE = {
    _TEST_UNARY_UNARY: grpc.unary_unary_rpc_method_handler(_test_unary_unary),
    _TEST_SET_COMPRESSION: grpc.stream_stream_rpc_method_handler(
        _test_set_compression
    ),
    _TEST_DISABLE_COMPRESSION_UNARY: grpc.unary_unary_rpc_method_handler(
        _test_disable_compression_unary
    ),
    _TEST_DISABLE_COMPRESSION_STREAM: grpc.stream_stream_rpc_method_handler(
        _test_disable_compression_stream
    ),
}


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler dispatching calls through _ROUTING_TABLE."""

    def service(self, handler_call_details):
        """Return the handler for the called method, or None if unknown."""
        return _ROUTING_TABLE.get(handler_call_details.method)


async def _start_test_server(options=None):
    """Start an insecure aio server serving _GenericHandler.

    Args:
        options: Optional channel arguments forwarded to aio.server.

    Returns:
        A tuple of the "localhost:<port>" address and the started server.
    """
    server = aio.server(options=options)
    # Port 0 requests any free port; the bound port is returned.
    port = server.add_insecure_port("[::]:0")
    server.add_generic_rpc_handlers((_GenericHandler(),))
    await server.start()
    return f"localhost:{port}", server


class TestCompression(AioTestBase):
    """Compression tests against a server started with gzip disabled."""

    async def setUp(self):
        server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,)
        self._address, self._server = await _start_test_server(server_options)
        self._channel = aio.insecure_channel(self._address)

    async def tearDown(self):
        await self._channel.close()
        await self._server.stop(None)

    async def test_channel_level_compression_baned_compression(self):
        """A channel defaulting to gzip is rejected with UNIMPLEMENTED."""
        # GZIP is disabled, this call should fail
        async with aio.insecure_channel(
            self._address, compression=grpc.Compression.Gzip
        ) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            with self.assertRaises(aio.AioRpcError) as exception_context:
                await call
            rpc_error = exception_context.exception
            self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_channel_level_compression_allowed_compression(self):
        """A channel defaulting to Deflate completes with OK."""
        # Deflate is allowed, this call should succeed
        async with aio.insecure_channel(
            self._address, compression=grpc.Compression.Deflate
        ) as channel:
            multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
            call = multicallable(_REQUEST)
            self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_client_call_level_compression_baned_compression(self):
        """A single call requesting gzip is rejected with UNIMPLEMENTED."""
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # GZIP is disabled, this call should fail
        call = multicallable(_REQUEST, compression=grpc.Compression.Gzip)
        with self.assertRaises(aio.AioRpcError) as exception_context:
            await call
        rpc_error = exception_context.exception
        self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())

    async def test_client_call_level_compression_allowed_compression(self):
        """A single call requesting Deflate completes with OK."""
        multicallable = self._channel.unary_unary(_TEST_UNARY_UNARY)

        # Deflate is allowed, this call should succeed
        call = multicallable(_REQUEST, compression=grpc.Compression.Deflate)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_call_level_compression(self):
        """The handler's own set_compression checks pass and the call is OK."""
        multicallable = self._channel.stream_stream(_TEST_SET_COMPRESSION)
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_unary(self):
        """The unary response is intact with compression disabled for it."""
        multicallable = self._channel.unary_unary(
            _TEST_DISABLE_COMPRESSION_UNARY
        )
        call = multicallable(_REQUEST)
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_disable_compression_stream(self):
        """All three streamed responses arrive intact and the call is OK."""
        multicallable = self._channel.stream_stream(
            _TEST_DISABLE_COMPRESSION_STREAM
        )
        call = multicallable()
        await call.write(_REQUEST)
        await call.done_writing()
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(_RESPONSE, await call.read())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_server_default_compression_algorithm(self):
        """A server created with compression=Deflate serves a plain call."""
        server = aio.server(compression=grpc.Compression.Deflate)
        port = server.add_insecure_port("[::]:0")
        server.add_generic_rpc_handlers((_GenericHandler(),))
        await server.start()

        # Stop this test's private server even when an assertion or the RPC
        # fails, so it does not outlive the test.
        try:
            async with aio.insecure_channel(f"localhost:{port}") as channel:
                multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
                call = multicallable(_REQUEST)
                self.assertEqual(_RESPONSE, await call)
                self.assertEqual(grpc.StatusCode.OK, await call.code())
        finally:
            await server.stop(None)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)