# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side compression.

Each generated test runs the same RPC twice through a byte-counting TCP
proxy: once with no compression configured anywhere (the baseline) and once
with the configuration under test. The relative change in the number of
bytes seen by the proxy in each direction tells us whether that direction
was compressed.
"""

import contextlib
from concurrent import futures
import functools
import itertools
import logging
import os
import unittest

import grpc
from grpc import _grpcio_metadata

from tests.unit import _tcp_proxy
from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Method names routed by _GenericHandler.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = 'localhost'

# A highly compressible payload: 100 zero bytes.
_REQUEST = b'\x00' * 100
# A direction counts as "compressed" only when it puts more than this
# fraction fewer bytes on the wire than the uncompressed baseline.
_COMPRESSION_RATIO_THRESHOLD = 0.05
# Values tried for every compression knob; None leaves the knob unset.
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)
# Human-readable fragments used when building generated test names.
_COMPRESSION_NAMES = {
    None: 'Uncompressed',
    grpc.Compression.NoCompression: 'NoCompression',
    grpc.Compression.Deflate: 'DeflateCompression',
    grpc.Compression.Gzip: 'GzipCompression',
}

# Every combination (Cartesian product) of these values becomes one test
# method on CompressionTest.
_TEST_OPTIONS = {
    'client_streaming': (True, False),
    'server_streaming': (True, False),
    'channel_compression': _COMPRESSION_METHODS,
    'multicallable_compression': _COMPRESSION_METHODS,
    'server_compression': _COMPRESSION_METHODS,
    'server_call_compression': _COMPRESSION_METHODS,
}

# Environment variable set while measuring byte counts (see
# _get_compression_ratios for why).
_DISABLE_FLOW_CONTROL_ENV_VAR = 'GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'


def _make_handle_unary_unary(pre_response_callback):
    """Returns a unary-unary handler that echoes the request.

    Args:
      pre_response_callback: Optional callable invoked as
        callback(request, servicer_context) before responding.
    """

    def _handle_unary(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        return request

    return _handle_unary


def _make_handle_unary_stream(pre_response_callback):
    """Returns a unary-stream handler that echoes the request
    _STREAM_LENGTH times.

    Args:
      pre_response_callback: Optional callable invoked as
        callback(request, servicer_context) before the first response.
    """

    def _handle_unary_stream(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        for _ in range(_STREAM_LENGTH):
            yield request

    return _handle_unary_stream


def _make_handle_stream_unary(pre_response_callback):
    """Returns a stream-unary handler that responds with the first request.

    Args:
      pre_response_callback: Optional callable invoked as
        callback(request_iterator, servicer_context) before the requests
        are consumed.
    """

    def _handle_stream_unary(request_iterator, servicer_context):
        if pre_response_callback:
            pre_response_callback(request_iterator, servicer_context)
        response = None
        # Drain the whole stream but remember only the first request. Test
        # against None (not truthiness) so an empty first message still
        # counts as the first request.
        for request in request_iterator:
            if response is None:
                response = request
        return response

    return _handle_stream_unary


def _make_handle_stream_stream(pre_response_callback):
    """Returns a stream-stream handler that echoes each request.

    Args:
      pre_response_callback: Optional callable invoked as
        callback(request, servicer_context) before each echoed response.
    """

    def _handle_stream(request_iterator, servicer_context):
        # TODO(issue:#6891) We should be able to remove this loop,
        # and replace with return; yield
        for request in request_iterator:
            if pre_response_callback:
                pre_response_callback(request, servicer_context)
            yield request

    return _handle_stream


def set_call_compression(compression_method, request_or_iterator,
                         servicer_context):
    """Pre-response callback that sets the compression method for the call.

    Intended to be bound with functools.partial(set_call_compression,
    method) so it matches the (request, servicer_context) callback shape.
    """
    del request_or_iterator
    servicer_context.set_compression(compression_method)


def disable_next_compression(request, servicer_context):
    """Pre-response callback that disables compression of the next message."""
    del request
    servicer_context.disable_next_message_compression()


def disable_first_compression(request, servicer_context):
    """Disables compression only for the response to request number 0.

    Expects requests shaped like those sent by _stream_stream_client:
    ASCII digits whose integer value is the request's index in the stream.
    """
    if int(request.decode('ascii')) == 0:
        servicer_context.disable_next_message_compression()


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler for one of the four RPC arities.

    No (de)serializers are set, so handlers receive and return the raw
    request bytes. Exactly one of the four behavior attributes is
    populated, chosen by the streaming flags.
    """

    def __init__(self, request_streaming, response_streaming,
                 pre_response_callback):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if self.request_streaming and self.response_streaming:
            self.stream_stream = _make_handle_stream_stream(
                pre_response_callback)
        elif not self.request_streaming and not self.response_streaming:
            self.unary_unary = _make_handle_unary_unary(pre_response_callback)
        elif not self.request_streaming and self.response_streaming:
            self.unary_stream = _make_handle_unary_stream(pre_response_callback)
        else:
            self.stream_unary = _make_handle_stream_unary(pre_response_callback)


class _GenericHandler(grpc.GenericRpcHandler):
    """Routes the four /test/* methods to echoing _MethodHandlers.

    Args:
      pre_response_callback: Optional callback forwarded to every handler
        this object creates (see the _make_handle_* factories).
    """

    def __init__(self, pre_response_callback):
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(False, False, self._pre_response_callback)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(False, True, self._pre_response_callback)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(True, False, self._pre_response_callback)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(True, True, self._pre_response_callback)
        else:
            return None


@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                     server_handler):
    """Starts a server, a TCP proxy in front of it, and a client channel.

    Yields:
      (client_channel, proxy, server). The channel connects to the proxy,
      which forwards to the server and counts the bytes it relays.

    The server is stopped on exit even if creating the proxy or the channel
    fails after the server has been started.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_generic_rpc_handlers((server_handler,))
    server_port = server.add_insecure_port('{}:0'.format(_HOST))
    server.start()
    try:
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
                                       **channel_kwargs) as client_channel:
                yield client_channel, proxy, server
    finally:
        server.stop(None)


def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
                     server_kwargs, server_handler, message):
    """Runs client_function once through the proxy.

    Returns:
      The proxy's byte count (unpacked by callers as
      (bytes_sent, bytes_received)), read before teardown.
    """
    with _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                          server_handler) as pipeline:
        client_channel, proxy, _ = pipeline
        client_function(client_channel, multicallable_kwargs, message)
        return proxy.get_byte_count()


def _get_compression_ratios(client_function, first_channel_kwargs,
                            first_multicallable_kwargs, first_server_kwargs,
                            first_server_handler, second_channel_kwargs,
                            second_multicallable_kwargs, second_server_kwargs,
                            second_server_handler, message):
    """Compares the bytes on the wire of two configurations of the same RPC.

    Returns:
      (sent_ratio, received_ratio), each computed as
      (second - first) / first. A negative ratio means the second
      configuration moved fewer bytes in that direction than the first.
    """
    # This test requires the byte length of each connection to be
    # deterministic. As it turns out, flow control puts bytes on the wire in
    # a nondeterministic manner. We disable it here in order to measure
    # compression ratios deterministically. Any pre-existing value of the
    # variable is restored afterwards so other tests are unaffected.
    previous_value = os.environ.get(_DISABLE_FLOW_CONTROL_ENV_VAR)
    os.environ[_DISABLE_FLOW_CONTROL_ENV_VAR] = 'true'
    try:
        first_bytes_sent, first_bytes_received = _get_byte_counts(
            first_channel_kwargs, first_multicallable_kwargs, client_function,
            first_server_kwargs, first_server_handler, message)
        second_bytes_sent, second_bytes_received = _get_byte_counts(
            second_channel_kwargs, second_multicallable_kwargs,
            client_function, second_server_kwargs, second_server_handler,
            message)
        return ((second_bytes_sent - first_bytes_sent) /
                float(first_bytes_sent),
                (second_bytes_received - first_bytes_received) /
                float(first_bytes_received))
    finally:
        if previous_value is None:
            os.environ.pop(_DISABLE_FLOW_CONTROL_ENV_VAR, None)
        else:
            os.environ[_DISABLE_FLOW_CONTROL_ENV_VAR] = previous_value


def _unary_unary_client(channel, multicallable_kwargs, message):
    """Sends message once and checks that it is echoed back."""
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    response = multi_callable(message, **multicallable_kwargs)
    if response != message:
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, response))


def _unary_stream_client(channel, multicallable_kwargs, message):
    """Sends message once and checks every streamed response echoes it."""
    multi_callable = channel.unary_stream(_UNARY_STREAM)
    response_iterator = multi_callable(message, **multicallable_kwargs)
    for response in response_iterator:
        if response != message:
            raise RuntimeError("Request '{}' != Response '{}'".format(
                message, response))


def _stream_unary_client(channel, multicallable_kwargs, message):
    """Streams message _STREAM_LENGTH times; expects it echoed back once."""
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    # Stream the message we compare against, not a fixed global payload.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, response))


def _stream_stream_client(channel, multicallable_kwargs, message):
    """Streams numbered requests and checks each echo carries its index.

    message is intentionally unused: each request is 100 ASCII '0'
    characters followed by the request's index, so the payload stays
    compressible while remaining parseable by disable_first_compression.
    """
    multi_callable = channel.stream_stream(_STREAM_STREAM)
    request_prefix = str(0).encode('ascii') * 100
    requests = (request_prefix + str(i).encode('ascii')
                for i in range(_STREAM_LENGTH))
    response_iterator = multi_callable(requests, **multicallable_kwargs)
    for i, response in enumerate(response_iterator):
        if int(response.decode('ascii')) != i:
            raise RuntimeError("Request '{}' != Response '{}'".format(
                i, response))


class CompressionTest(unittest.TestCase):
    """Compression tests; most test methods are generated at import time."""

    def assertCompressed(self, compression_ratio):
        """Asserts the ratio is below -_COMPRESSION_RATIO_THRESHOLD."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertNotCompressed(self, compression_ratio):
        """Asserts the ratio is at least -_COMPRESSION_RATIO_THRESHOLD."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertConfigurationCompressed(self, client_streaming, server_streaming,
                                      channel_compression,
                                      multicallable_compression,
                                      server_compression,
                                      server_call_compression):
        """Checks each direction is compressed iff some knob enables it.

        The baseline run sets no compression anywhere; the measured run
        applies the given channel, multicallable, server and per-call
        settings (each None when unset).
        """
        client_side_compressed = channel_compression or multicallable_compression
        server_side_compressed = server_compression or server_call_compression
        channel_kwargs = {
            'compression': channel_compression,
        } if channel_compression else {}
        multicallable_kwargs = {
            'compression': multicallable_compression,
        } if multicallable_compression else {}

        if not client_streaming and not server_streaming:
            client_function = _unary_unary_client
        elif not client_streaming and server_streaming:
            client_function = _unary_stream_client
        elif client_streaming and not server_streaming:
            client_function = _stream_unary_client
        else:
            client_function = _stream_stream_client

        server_kwargs = {
            'compression': server_compression,
        } if server_compression else {}
        # Apply the per-call method actually under test, so that the
        # generated test name matches the configuration being exercised.
        if server_call_compression:
            server_handler = _GenericHandler(
                functools.partial(set_call_compression,
                                  server_call_compression))
        else:
            server_handler = _GenericHandler(None)
        sent_ratio, received_ratio = _get_compression_ratios(
            client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
            multicallable_kwargs, server_kwargs, server_handler, _REQUEST)

        if client_side_compressed:
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if server_side_compressed:
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
        self.assertCompressed(received_ratio)


def _get_compression_str(name, value):
    """Returns e.g. 'ChannelGzipCompression' for ('Channel', Gzip)."""
    return '{}{}'.format(name, _COMPRESSION_NAMES[value])


def _get_compression_test_name(client_streaming, server_streaming,
                               channel_compression, multicallable_compression,
                               server_compression, server_call_compression):
    """Builds a unique test-method name from one option combination."""
    client_arity = 'Stream' if client_streaming else 'Unary'
    server_arity = 'Stream' if server_streaming else 'Unary'
    arity = '{}{}'.format(client_arity, server_arity)
    channel_compression_str = _get_compression_str('Channel',
                                                   channel_compression)
    multicallable_compression_str = _get_compression_str(
        'Multicallable', multicallable_compression)
    server_compression_str = _get_compression_str('Server', server_compression)
    server_call_compression_str = _get_compression_str('ServerCall',
                                                       server_call_compression)
    return 'test{}{}{}{}{}'.format(arity, channel_compression_str,
                                   multicallable_compression_str,
                                   server_compression_str,
                                   server_call_compression_str)


def _test_options():
    """Yields one kwargs dict per combination of _TEST_OPTIONS values."""
    for test_parameters in itertools.product(*_TEST_OPTIONS.values()):
        yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))


# Attach one generated test method per option combination. The factory
# function binds each combination's kwargs at call time, avoiding the
# late-binding-closure pitfall of defining the method directly in the loop.
for options in _test_options():

    def test_compression(**kwargs):

        def _test_compression(self):
            self.assertConfigurationCompressed(**kwargs)

        return _test_compression

    setattr(CompressionTest, _get_compression_test_name(**options),
            test_compression(**options))

if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)