• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server and client side compression."""
15
16import unittest
17
18import contextlib
19from concurrent import futures
20import functools
21import itertools
22import logging
23import os
24
25import grpc
26from grpc import _grpcio_metadata
27
28from tests.unit import test_common
29from tests.unit.framework.common import test_constants
30from tests.unit import _tcp_proxy
31
# Method paths; _GenericHandler.service matches incoming calls against these.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Cut down on test time.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

# Host used for the server, the TCP proxy and the client channel.
_HOST = 'localhost'

# Request payload: 100 zero bytes.
_REQUEST = b'\x00' * 100
# A byte-count ratio counts as "compressed" only when it is below the negated
# threshold (see CompressionTest.assertCompressed / assertNotCompressed).
_COMPRESSION_RATIO_THRESHOLD = 0.05
# Compression settings each option below is exercised with.
_COMPRESSION_METHODS = (
    None,
    # Disabled for test tractability.
    # grpc.Compression.NoCompression,
    # grpc.Compression.Deflate,
    grpc.Compression.Gzip,
)
# Name fragments used when building generated test method names.
_COMPRESSION_NAMES = {
    None: 'Uncompressed',
    grpc.Compression.NoCompression: 'NoCompression',
    grpc.Compression.Deflate: 'DeflateCompression',
    grpc.Compression.Gzip: 'GzipCompression',
}

# Option name -> candidate values. _test_options() yields one dict per element
# of the cartesian product of these values; the keys match the parameters of
# CompressionTest.assertConfigurationCompressed.
_TEST_OPTIONS = {
    'client_streaming': (True, False),
    'server_streaming': (True, False),
    'channel_compression': _COMPRESSION_METHODS,
    'multicallable_compression': _COMPRESSION_METHODS,
    'server_compression': _COMPRESSION_METHODS,
    'server_call_compression': _COMPRESSION_METHODS,
}
66
67
def _make_handle_unary_unary(pre_response_callback):
    """Builds a unary-unary behavior that echoes its request.

    Args:
      pre_response_callback: Optional callable invoked as
        pre_response_callback(request, servicer_context) before the response
        is returned. Not invoked when falsy.

    Returns:
      A callable taking (request, servicer_context) and returning the request.
    """

    def _echo_request(request, servicer_context):
        if not pre_response_callback:
            return request
        pre_response_callback(request, servicer_context)
        return request

    return _echo_request
76
77
def _make_handle_unary_stream(pre_response_callback):
    """Builds a unary-stream behavior that repeats its request.

    Args:
      pre_response_callback: Optional callable invoked once as
        pre_response_callback(request, servicer_context) before the first
        response is produced. Not invoked when falsy.

    Returns:
      A generator function taking (request, servicer_context) that yields the
      request _STREAM_LENGTH times.
    """

    def _repeat_request(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        for response in itertools.repeat(request, _STREAM_LENGTH):
            yield response

    return _repeat_request
87
88
def _make_handle_stream_unary(pre_response_callback):
    """Builds a stream-unary behavior that echoes the first request.

    Args:
      pre_response_callback: Optional callable invoked once as
        pre_response_callback(request_iterator, servicer_context) before any
        request is consumed. Not invoked when falsy.

    Returns:
      A callable taking (request_iterator, servicer_context). It consumes the
      whole request iterator and returns the first request, or None when the
      iterator produced no requests.
    """

    def _handle_stream_unary(request_iterator, servicer_context):
        if pre_response_callback:
            pre_response_callback(request_iterator, servicer_context)
        response = None
        # Keep iterating after the first request so the stream is drained.
        for request in request_iterator:
            # Compare against None rather than truthiness so that a falsy
            # first request (e.g. an empty payload) is not overwritten by a
            # later one.
            if response is None:
                response = request
        return response

    return _handle_stream_unary
101
102
def _make_handle_stream_stream(pre_response_callback):
    """Builds a stream-stream behavior that echoes every request.

    Args:
      pre_response_callback: Optional callable invoked as
        pre_response_callback(request, servicer_context) for each request,
        immediately before that request is yielded back. Not invoked when
        falsy.

    Returns:
      A generator function taking (request_iterator, servicer_context) that
      yields each request in order.
    """

    def _echo_requests(request_iterator, servicer_context):
        # TODO(issue:#6891) We should be able to remove this loop,
        # and replace with return; yield
        for incoming in request_iterator:
            if pre_response_callback:
                pre_response_callback(incoming, servicer_context)
            yield incoming

    return _echo_requests
114
115
def set_call_compression(compression_method, request_or_iterator,
                         servicer_context):
    """Pre-response callback that sets the compression for the call.

    Args:
      compression_method: Value forwarded to servicer_context.set_compression.
      request_or_iterator: Ignored; accepted so that, once compression_method
        is bound, the signature matches the pre-response callback signature.
      servicer_context: Object exposing set_compression().
    """
    # The request is irrelevant here; only the context is acted upon.
    del request_or_iterator
    apply_compression = servicer_context.set_compression
    apply_compression(compression_method)
120
121
def disable_next_compression(request, servicer_context):
    """Pre-response callback that disables compression of the next message.

    Args:
      request: Ignored.
      servicer_context: Object exposing disable_next_message_compression().
    """
    # The request content plays no part in the decision.
    del request
    disable = servicer_context.disable_next_message_compression
    disable()
125
126
def disable_first_compression(request, servicer_context):
    """Pre-response callback disabling compression for request number zero.

    Args:
      request: Bytes holding an ASCII-encoded integer.
      servicer_context: Object exposing disable_next_message_compression().
        It is only called when the decoded integer equals 0.
    """
    request_number = int(request.decode('ascii'))
    if request_number != 0:
        return
    servicer_context.disable_next_message_compression()
130
131
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one echo behavior.

    The behavior matching the (request_streaming, response_streaming) pair is
    populated; the other three behavior attributes stay None. No request
    deserializer or response serializer is configured.
    """

    def __init__(self, request_streaming, response_streaming,
                 pre_response_callback):
        """Initializes the handler.

        Args:
          request_streaming: Whether the client sends a stream of requests.
          response_streaming: Whether the server sends a stream of responses.
          pre_response_callback: Optional callable handed to the behavior
            factory for the selected arity.
        """
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        # Branch on the request side first, then on the response side.
        if request_streaming:
            if response_streaming:
                self.stream_stream = _make_handle_stream_stream(
                    pre_response_callback)
            else:
                self.stream_unary = _make_handle_stream_unary(
                    pre_response_callback)
        elif response_streaming:
            self.unary_stream = _make_handle_unary_stream(
                pre_response_callback)
        else:
            self.unary_unary = _make_handle_unary_unary(pre_response_callback)
154
155
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler serving the four test methods."""

    def __init__(self, pre_response_callback):
        """Initializes the handler.

        Args:
          pre_response_callback: Optional callable passed on to every
            _MethodHandler this handler creates.
        """
        self._pre_response_callback = pre_response_callback

    def service(self, handler_call_details):
        """Returns the method handler for the called method.

        Args:
          handler_call_details: Object whose `method` attribute is compared
            against the known method paths.

        Returns:
          A new _MethodHandler for a known method path, otherwise None.
        """
        # Method path -> (request_streaming, response_streaming).
        arity_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        arity = arity_by_method.get(handler_call_details.method)
        if arity is None:
            return None
        request_streaming, response_streaming = arity
        return _MethodHandler(request_streaming, response_streaming,
                              self._pre_response_callback)
172
173
@contextlib.contextmanager
def _instrumented_client_server_pair(channel_kwargs, server_kwargs,
                                     server_handler):
    """Yields a client channel connected to a server through a TCP proxy.

    Args:
      channel_kwargs: Keyword arguments forwarded to grpc.insecure_channel.
      server_kwargs: Keyword arguments forwarded to grpc.server.
      server_handler: Generic RPC handler registered on the server.

    Yields:
      A (client_channel, proxy, server) tuple. The channel targets the
      proxy's port, and the proxy is constructed with the server's port.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    server.add_generic_rpc_handlers((server_handler,))
    # Port 0 is requested; the port actually bound is the return value.
    server_port = server.add_insecure_port('{}:0'.format(_HOST))
    server.start()
    with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
        proxy_port = proxy.get_port()
        with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port),
                                   **channel_kwargs) as client_channel:
            try:
                yield client_channel, proxy, server
            finally:
                # Runs before the channel and proxy contexts are exited.
                server.stop(None)
189
190
def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function,
                     server_kwargs, server_handler, message):
    """Runs one client scenario and reports the proxy's byte count.

    Args:
      channel_kwargs: Keyword arguments for the client channel.
      multicallable_kwargs: Keyword arguments handed to client_function.
      client_function: Callable invoked as
        client_function(channel, multicallable_kwargs, message).
      server_kwargs: Keyword arguments for the server.
      server_handler: Generic RPC handler registered on the server.
      message: Message handed to client_function.

    Returns:
      Whatever proxy.get_byte_count() returns once client_function finished.
    """
    client_server_pair = _instrumented_client_server_pair(
        channel_kwargs, server_kwargs, server_handler)
    # The server element of the yielded tuple is not needed here.
    with client_server_pair as (client_channel, proxy, _):
        client_function(client_channel, multicallable_kwargs, message)
        return proxy.get_byte_count()
198
199
def _get_compression_ratios(client_function, first_channel_kwargs,
                            first_multicallable_kwargs, first_server_kwargs,
                            first_server_handler, second_channel_kwargs,
                            second_multicallable_kwargs, second_server_kwargs,
                            second_server_handler, message):
    """Compares the bytes transferred by two client/server configurations.

    The same client_function is run against the first configuration and then
    against the second one.

    Args:
      client_function: Callable invoked as
        client_function(channel, multicallable_kwargs, message).
      first_channel_kwargs: Channel keyword arguments of the first run.
      first_multicallable_kwargs: Multicallable keyword arguments of the
        first run.
      first_server_kwargs: Server keyword arguments of the first run.
      first_server_handler: Generic RPC handler of the first run.
      second_channel_kwargs: Channel keyword arguments of the second run.
      second_multicallable_kwargs: Multicallable keyword arguments of the
        second run.
      second_server_kwargs: Server keyword arguments of the second run.
      second_server_handler: Generic RPC handler of the second run.
      message: Message handed to client_function in both runs.

    Returns:
      A (sent_ratio, received_ratio) tuple of floats, each computed as
      (second - first) / first. A negative ratio means the second run
      transferred fewer bytes than the first.
    """
    flow_control_flag = 'GRPC_EXPERIMENTAL_DISABLE_FLOW_CONTROL'
    # Remember any value present before the call so it can be restored
    # instead of being unconditionally deleted afterwards.
    previous_flag_value = os.environ.get(flow_control_flag)
    # This test requires the byte length of each connection to be
    # deterministic. As it turns out, flow control puts bytes on the wire in
    # a nondeterministic manner. We disable it here in order to measure
    # compression ratios deterministically.
    os.environ[flow_control_flag] = 'true'
    try:
        first_bytes_sent, first_bytes_received = _get_byte_counts(
            first_channel_kwargs, first_multicallable_kwargs, client_function,
            first_server_kwargs, first_server_handler, message)
        second_bytes_sent, second_bytes_received = _get_byte_counts(
            second_channel_kwargs, second_multicallable_kwargs, client_function,
            second_server_kwargs, second_server_handler, message)
        return ((second_bytes_sent - first_bytes_sent) /
                float(first_bytes_sent),
                (second_bytes_received - first_bytes_received) /
                float(first_bytes_received))
    finally:
        if previous_flag_value is None:
            os.environ.pop(flow_control_flag, None)
        else:
            os.environ[flow_control_flag] = previous_flag_value
223
224
def _unary_unary_client(channel, multicallable_kwargs, message):
    """Performs one unary-unary call and checks that the message is echoed.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the invocation.
      message: Request to send; the response must equal it.

    Raises:
      RuntimeError: If the response differs from the message.
    """
    invoke = channel.unary_unary(_UNARY_UNARY)
    reply = invoke(message, **multicallable_kwargs)
    if reply != message:
        mismatch = "Request '{}' != Response '{}'".format(message, reply)
        raise RuntimeError(mismatch)
231
232
def _unary_stream_client(channel, multicallable_kwargs, message):
    """Performs one unary-stream call and checks every response.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the invocation.
      message: Request to send; each response must equal it.

    Raises:
      RuntimeError: On the first response that differs from the message.
    """
    invoke = channel.unary_stream(_UNARY_STREAM)
    for reply in invoke(message, **multicallable_kwargs):
        if reply != message:
            mismatch = "Request '{}' != Response '{}'".format(message, reply)
            raise RuntimeError(mismatch)
240
241
def _stream_unary_client(channel, multicallable_kwargs, message):
    """Streams the message repeatedly and checks the single response.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the invocation.
      message: Request sent _STREAM_LENGTH times; the response must equal it.

    Raises:
      RuntimeError: If the response differs from the message.
    """
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    # Send the message that was passed in, not the module-level default, so
    # that the requests and the comparison below always agree.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError("Request '{}' != Response '{}'".format(
            message, response))
249
250
def _stream_stream_client(channel, multicallable_kwargs, message):
    """Streams numbered requests and checks they are echoed in order.

    Request i is 100 ASCII '0' characters followed by the decimal digits of
    i, so that it still parses as the integer i.

    Args:
      channel: Channel used to create the multi-callable.
      multicallable_kwargs: Keyword arguments for the invocation.
      message: Not used by this client.

    Raises:
      RuntimeError: If a response does not decode to its position in the
        response stream.
    """
    invoke = channel.stream_stream(_STREAM_STREAM)
    zero_padding = str(0).encode('ascii') * 100
    numbered_requests = (zero_padding + str(index).encode('ascii')
                         for index in range(_STREAM_LENGTH))
    replies = invoke(numbered_requests, **multicallable_kwargs)
    for position, reply in enumerate(replies):
        if int(reply.decode('ascii')) == position:
            continue
        raise RuntimeError("Request '{}' != Response '{}'".format(
            position, reply))
261
262
class CompressionTest(unittest.TestCase):
    """Checks compression by comparing byte counts against a baseline run.

    Ratios come from _get_compression_ratios, where the first run uses no
    compression settings at all and the second run uses the configuration
    under test.
    """

    def assertCompressed(self, compression_ratio):
        """Asserts that the ratio is below the negated threshold."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertNotCompressed(self, compression_ratio):
        """Asserts that the ratio is at or above the negated threshold."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg='Actual compression ratio: {}'.format(compression_ratio))

    def assertConfigurationCompressed(self, client_streaming, server_streaming,
                                      channel_compression,
                                      multicallable_compression,
                                      server_compression,
                                      server_call_compression):
        """Runs one configuration and checks both traffic directions.

        Sent bytes are expected to be compressed when either the channel or
        the multicallable has a (truthy) compression setting; received bytes
        when either the server or the server call has one.

        Args:
          client_streaming: Whether the client streams requests.
          server_streaming: Whether the server streams responses.
          channel_compression: Compression passed to the channel, or a falsy
            value to pass none.
          multicallable_compression: Compression passed on invocation, or a
            falsy value to pass none.
          server_compression: Compression passed to the server, or a falsy
            value to pass none.
          server_call_compression: Compression set on each call by the
            server's pre-response callback, or a falsy value to set none.
        """
        client_side_compressed = channel_compression or multicallable_compression
        server_side_compressed = server_compression or server_call_compression
        channel_kwargs = {
            'compression': channel_compression,
        } if channel_compression else {}
        multicallable_kwargs = {
            'compression': multicallable_compression,
        } if multicallable_compression else {}

        client_function = None
        if not client_streaming and not server_streaming:
            client_function = _unary_unary_client
        elif not client_streaming and server_streaming:
            client_function = _unary_stream_client
        elif client_streaming and not server_streaming:
            client_function = _stream_unary_client
        else:
            client_function = _stream_stream_client

        server_kwargs = {
            'compression': server_compression,
        } if server_compression else {}
        # Use the requested per-call method rather than a hard-coded one, so
        # the check stays correct if more methods are enabled in the options.
        server_handler = _GenericHandler(
            functools.partial(set_call_compression, server_call_compression)
        ) if server_call_compression else _GenericHandler(None)
        sent_ratio, received_ratio = _get_compression_ratios(
            client_function, {}, {}, {}, _GenericHandler(None), channel_kwargs,
            multicallable_kwargs, server_kwargs, server_handler, _REQUEST)

        if client_side_compressed:
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if server_side_compressed:
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        """Disabling compression before every response leaves none compressed."""
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_next_compression), _REQUEST)
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        """Disabling compression only for request 0 keeps the rest compressed."""
        server_kwargs = {
            'compression': grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client, {}, {}, {}, _GenericHandler(None), {}, {},
            server_kwargs, _GenericHandler(disable_first_compression), _REQUEST)
        self.assertCompressed(received_ratio)
338
339
def _get_compression_str(name, value):
    """Returns name followed by the display name of a compression value.

    Args:
      name: Prefix for the result.
      value: Key of _COMPRESSION_NAMES.

    Raises:
      KeyError: If value is not a key of _COMPRESSION_NAMES.
    """
    compression_name = _COMPRESSION_NAMES[value]
    return '{}{}'.format(name, compression_name)
342
343
def _get_compression_test_name(client_streaming, server_streaming,
                               channel_compression, multicallable_compression,
                               server_compression, server_call_compression):
    """Builds the name of the generated test method for one configuration.

    The name is 'test', then the client and server arity ('Stream' or
    'Unary' each), then one labelled compression name per setting, in the
    order Channel, Multicallable, Server, ServerCall.

    Args:
      client_streaming: Whether the client streams requests.
      server_streaming: Whether the server streams responses.
      channel_compression: Compression value for the channel.
      multicallable_compression: Compression value for the multicallable.
      server_compression: Compression value for the server.
      server_call_compression: Compression value for the server call.

    Returns:
      The test method name as a string.
    """
    arity = '{}{}'.format('Stream' if client_streaming else 'Unary',
                          'Stream' if server_streaming else 'Unary')
    labelled_settings = (
        ('Channel', channel_compression),
        ('Multicallable', multicallable_compression),
        ('Server', server_compression),
        ('ServerCall', server_call_compression),
    )
    setting_parts = [
        _get_compression_str(label, setting)
        for label, setting in labelled_settings
    ]
    return 'test{}{}{}{}{}'.format(arity, *setting_parts)
361
362
def _test_options():
    """Yields one option dict per combination of the _TEST_OPTIONS values.

    Each yielded dict maps every key of _TEST_OPTIONS to one of its candidate
    values; together the dicts cover the full cartesian product.
    """
    option_names = list(_TEST_OPTIONS.keys())
    candidate_values = [_TEST_OPTIONS[name] for name in option_names]
    for combination in itertools.product(*candidate_values):
        yield dict(zip(option_names, combination))
366
367
def test_compression(**kwargs):
    """Returns a test method bound to one option combination.

    The keyword arguments are captured here, once per call, so every
    generated method keeps its own combination.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**kwargs)

    return _test_compression


# Attach one generated test method to CompressionTest per option combination.
for options in _test_options():
    setattr(CompressionTest, _get_compression_test_name(**options),
            test_compression(**options))

if __name__ == '__main__':
    logging.basicConfig()
    unittest.main(verbosity=2)
383