# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side compression."""

from concurrent import futures
import contextlib
import functools
import itertools
import logging
import os
import unittest

import grpc
from grpc import _grpcio_metadata

from tests.unit import _tcp_proxy
from tests.unit.framework.common import test_constants

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = "localhost"

_REQUEST = b"\x00" * 100
# Minimum relative reduction in proxied bytes that assertCompressed treats as
# "compressed" (a ratio below -0.05).
_COMPRESSION_RATIO_THRESHOLD = 0.05
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)
_COMPRESSION_NAMES = {
    None: "Uncompressed",
    grpc.Compression.NoCompression: "NoCompression",
    grpc.Compression.Deflate: "DeflateCompression",
    grpc.Compression.Gzip: "GzipCompression",
}

# Every combination of these values becomes one generated test method.
_TEST_OPTIONS = {
    "client_streaming": (True, False),
    "server_streaming": (True, False),
    "channel_compression": _COMPRESSION_METHODS,
    "multicallable_compression": _COMPRESSION_METHODS,
    "server_compression": _COMPRESSION_METHODS,
    "server_call_compression": _COMPRESSION_METHODS,
}


def _make_handle_unary_unary(pre_response_callback):
    """Return a unary-unary behavior that echoes the request.

    Args:
      pre_response_callback: Optional callable invoked with
        (request, servicer_context) before the response is returned.
    """

    def _handle_unary(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        return request

    return _handle_unary


def _make_handle_unary_stream(pre_response_callback):
    """Return a unary-stream behavior that yields the request
    _STREAM_LENGTH times.

    Args:
      pre_response_callback: Optional callable invoked with
        (request, servicer_context) before the first response.
    """

    def _handle_unary_stream(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        for _ in range(_STREAM_LENGTH):
            yield request

    return _handle_unary_stream


def _make_handle_stream_unary(pre_response_callback):
    """Return a stream-unary behavior that responds with the first request.

    All requests are drained before responding. None is returned if the
    stream is empty.

    Args:
      pre_response_callback: Optional callable invoked with
        (request_iterator, servicer_context) before any request is read.
    """

    def _handle_stream_unary(request_iterator, servicer_context):
        if pre_response_callback:
            pre_response_callback(request_iterator, servicer_context)
        response = None
        for request in request_iterator:
            # Compare against None so an empty-bytes first request is still
            # used as the response.
            if response is None:
                response = request
        return response

    return _handle_stream_unary


def _make_handle_stream_stream(pre_response_callback):
    """Return a stream-stream behavior that echoes every request.

    Args:
      pre_response_callback: Optional callable invoked with
        (request, servicer_context) before each echoed response.
    """

    def _handle_stream(request_iterator, servicer_context):
        # TODO(issue:#6891) We should be able to remove this loop,
        # and replace with return; yield
        for request in request_iterator:
            if pre_response_callback:
                pre_response_callback(request, servicer_context)
            yield request

    return _handle_stream


def set_call_compression(
    compression_method, request_or_iterator, servicer_context
):
    """Pre-response callback: set per-call compression on the server side."""
    del request_or_iterator
    servicer_context.set_compression(compression_method)


def disable_next_compression(request, servicer_context):
    """Pre-response callback: disable compression of the next message."""
    del request
    servicer_context.disable_next_message_compression()


def disable_first_compression(request, servicer_context):
    """Pre-response callback: disable compression only for request 0.

    Assumes the request is an ASCII decimal string, as sent by
    _stream_stream_client.
    """
    if int(request.decode("ascii")) == 0:
        servicer_context.disable_next_message_compression()


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler for the echo behaviors above, using raw bytes.

    No serializers are set, so requests and responses are passed through
    unchanged. Exactly one of the four behavior attributes is populated,
    chosen by the streaming flags.
    """

    def __init__(
        self, request_streaming, response_streaming, pre_response_callback
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if self.request_streaming and self.response_streaming:
            self.stream_stream = _make_handle_stream_stream(
                pre_response_callback
            )
        elif not self.request_streaming and not self.response_streaming:
            self.unary_unary = _make_handle_unary_unary(pre_response_callback)
        elif not self.request_streaming and self.response_streaming:
            self.unary_stream = _make_handle_unary_stream(pre_response_callback)
        else:
            self.stream_unary = _make_handle_stream_unary(pre_response_callback)


def get_method_handlers(pre_response_callback):
    """Return the method-name -> handler map for all four RPC arities."""
    return {
        _UNARY_UNARY: _MethodHandler(False, False, pre_response_callback),
        _UNARY_STREAM: _MethodHandler(False, True, pre_response_callback),
        _STREAM_UNARY: _MethodHandler(True, False, pre_response_callback),
        _STREAM_STREAM: _MethodHandler(True, True, pre_response_callback),
    }


@contextlib.contextmanager
def _instrumented_client_server_pair(
    channel_kwargs, server_kwargs, server_handler
):
    """Start a server behind a byte-counting TCP proxy and yield a channel.

    Yields:
      A (client_channel, proxy, server) tuple. The channel connects through
      the proxy, so proxy byte counts reflect all client/server traffic.

    The server is stopped on exit even if setting up the proxy or channel
    fails.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_registered_method_handlers(_SERVICE_NAME, server_handler)
    server_port = server.add_insecure_port("{}:0".format(_HOST))
    server.start()
    try:
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel(
                "{}:{}".format(_HOST, proxy_port), **channel_kwargs
            ) as client_channel:
                yield client_channel, proxy, server
    finally:
        # Outermost so the started server is never leaked, whichever setup
        # step (proxy, channel) or test body raised.
        server.stop(None)


def _get_byte_counts(
    channel_kwargs,
    multicallable_kwargs,
    client_function,
    server_kwargs,
    server_handler,
    message,
):
    """Run client_function once and return the proxy's byte counts.

    Returns:
      Whatever the proxy's get_byte_count() reports. Callers unpack it as a
      (bytes_sent, bytes_received) pair.
    """
    with _instrumented_client_server_pair(
        channel_kwargs, server_kwargs, server_handler
    ) as (client_channel, proxy, _):
        client_function(client_channel, multicallable_kwargs, message)
        return proxy.get_byte_count()


def _get_compression_ratios(
    client_function,
    first_channel_kwargs,
    first_multicallable_kwargs,
    first_server_kwargs,
    first_server_handler,
    second_channel_kwargs,
    second_multicallable_kwargs,
    second_server_kwargs,
    second_server_handler,
    message,
):
    """Compare proxied traffic between two configurations of the same RPC.

    Returns:
      A (sent_ratio, received_ratio) tuple. Each value is
      (second - first) / first, so negative values mean the second
      configuration moved fewer bytes.
    """
    first_bytes_sent, first_bytes_received = _get_byte_counts(
        first_channel_kwargs,
        first_multicallable_kwargs,
        client_function,
        first_server_kwargs,
        first_server_handler,
        message,
    )
    second_bytes_sent, second_bytes_received = _get_byte_counts(
        second_channel_kwargs,
        second_multicallable_kwargs,
        client_function,
        second_server_kwargs,
        second_server_handler,
        message,
    )
    return (
        (second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
        (second_bytes_received - first_bytes_received)
        / float(first_bytes_received),
    )


def _unary_unary_client(channel, multicallable_kwargs, message):
    """Send message once; raise RuntimeError unless it is echoed back."""
    multi_callable = channel.unary_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_UNARY),
        _registered_method=True,
    )
    response = multi_callable(message, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )


def _unary_stream_client(channel, multicallable_kwargs, message):
    """Send message once; raise RuntimeError if any streamed response
    differs from it."""
    multi_callable = channel.unary_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _UNARY_STREAM),
        _registered_method=True,
    )
    response_iterator = multi_callable(message, **multicallable_kwargs)
    for response in response_iterator:
        if response != message:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(message, response)
            )


def _stream_unary_client(channel, multicallable_kwargs, message):
    """Stream message _STREAM_LENGTH times; raise RuntimeError unless the
    single response equals message."""
    multi_callable = channel.stream_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
        _registered_method=True,
    )
    # Send the caller-supplied message, like the other client helpers.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )


def _stream_stream_client(channel, multicallable_kwargs, message):
    """Stream numbered requests and check the echoed sequence numbers.

    Each request is 100 ASCII '0' characters followed by its index, so it
    is both compressible and parseable as an int. The message argument is
    not used by this client.
    """
    multi_callable = channel.stream_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
        _registered_method=True,
    )
    request_prefix = str(0).encode("ascii") * 100
    requests = (
        request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH)
    )
    response_iterator = multi_callable(requests, **multicallable_kwargs)
    for i, response in enumerate(response_iterator):
        if int(response.decode("ascii")) != i:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(i, response)
            )


class CompressionTest(unittest.TestCase):
    """Compression tests; per-configuration methods are attached below."""

    def assertCompressed(self, compression_ratio):
        """Assert the ratio shows a reduction beyond the threshold."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertNotCompressed(self, compression_ratio):
        """Assert the ratio shows no reduction beyond the threshold."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertConfigurationCompressed(
        self,
        client_streaming,
        server_streaming,
        channel_compression,
        multicallable_compression,
        server_compression,
        server_call_compression,
    ):
        """Run one compression configuration against an uncompressed baseline.

        The RPC arity chosen by client_streaming/server_streaming runs twice:
        once with no compression settings, and once with the given channel,
        multicallable, server and per-call server compression.

        NOTE(review): the ratios returned by _get_compression_ratios are not
        checked. This only verifies that each configuration completes
        without error; assertCompressed/assertNotCompressed are not called.
        """
        channel_kwargs = (
            {"compression": channel_compression} if channel_compression else {}
        )
        multicallable_kwargs = (
            {"compression": multicallable_compression}
            if multicallable_compression
            else {}
        )
        server_kwargs = (
            {"compression": server_compression} if server_compression else {}
        )

        client_function = {
            (False, False): _unary_unary_client,
            (False, True): _unary_stream_client,
            (True, False): _stream_unary_client,
            (True, True): _stream_stream_client,
        }[(bool(client_streaming), bool(server_streaming))]

        if server_call_compression:
            # Use the requested per-call method rather than assuming Gzip, so
            # re-enabling other entries in _COMPRESSION_METHODS tests them.
            pre_response_callback = functools.partial(
                set_call_compression, server_call_compression
            )
        else:
            pre_response_callback = None
        server_handler = get_method_handlers(pre_response_callback)

        _get_compression_ratios(
            client_function,
            {},
            {},
            {},
            get_method_handlers(None),
            channel_kwargs,
            multicallable_kwargs,
            server_kwargs,
            server_handler,
            _REQUEST,
        )

    def testDisableNextCompressionStreaming(self):
        """Stream-stream RPC where the server disables every next message's
        compression (ratios are not checked)."""
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            get_method_handlers(None),
            {},
            {},
            server_kwargs,
            get_method_handlers(disable_next_compression),
            _REQUEST,
        )

    def testDisableNextCompressionStreamingResets(self):
        """Stream-stream RPC where the server disables compression only for
        the first message (ratios are not checked)."""
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            get_method_handlers(None),
            {},
            {},
            server_kwargs,
            get_method_handlers(disable_first_compression),
            _REQUEST,
        )


def _get_compression_str(name, value):
    """Return name followed by the display name of compression value."""
    return "{}{}".format(name, _COMPRESSION_NAMES[value])


def _get_compression_test_name(
    client_streaming,
    server_streaming,
    channel_compression,
    multicallable_compression,
    server_compression,
    server_call_compression,
):
    """Build a unique test-method name describing one configuration."""
    client_arity = "Stream" if client_streaming else "Unary"
    server_arity = "Stream" if server_streaming else "Unary"
    arity = "{}{}".format(client_arity, server_arity)
    channel_compression_str = _get_compression_str(
        "Channel", channel_compression
    )
    multicallable_compression_str = _get_compression_str(
        "Multicallable", multicallable_compression
    )
    server_compression_str = _get_compression_str("Server", server_compression)
    server_call_compression_str = _get_compression_str(
        "ServerCall", server_call_compression
    )
    return "test{}{}{}{}{}".format(
        arity,
        channel_compression_str,
        multicallable_compression_str,
        server_compression_str,
        server_call_compression_str,
    )


def _test_options():
    """Yield one keyword-argument dict per combination in _TEST_OPTIONS."""
    for test_parameters in itertools.product(*_TEST_OPTIONS.values()):
        yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))


def _make_compression_test(**kwargs):
    """Return a test method that runs assertConfigurationCompressed(**kwargs).

    Private so test collectors do not mistake the factory itself for a test.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


for options in _test_options():
    setattr(
        CompressionTest,
        _get_compression_test_name(**options),
        _make_compression_test(**options),
    )

if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)