• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server and client side compression."""
15
16from concurrent import futures
17import contextlib
18import functools
19import itertools
20import logging
21import os
22import unittest
23
24import grpc
25from grpc import _grpcio_metadata
26
27from tests.unit import _tcp_proxy
28from tests.unit.framework.common import test_constants
29
# Names of the test service and its four methods, one per RPC arity.
_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Messages per streaming RPC; a sixteenth of the shared default keeps the
# whole configuration matrix reasonably fast.
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16

_HOST = "localhost"

# 100 zero bytes: trivially compressible, so compression shows up clearly in
# the proxy's byte counts.
_REQUEST = bytes(100)

# A run counts as compressed only if it reduces the byte count by more than
# this fraction (see CompressionTest.assertCompressed).
_COMPRESSION_RATIO_THRESHOLD = 0.05

# Only "unset" (None) and Gzip are exercised; NoCompression and Deflate are
# left out of the matrix to keep the number of generated tests tractable.
_COMPRESSION_METHODS = (None, grpc.Compression.Gzip)

# Labels used when building generated test method names.
_COMPRESSION_NAMES = {
    None: "Uncompressed",
    grpc.Compression.NoCompression: "NoCompression",
    grpc.Compression.Deflate: "DeflateCompression",
    grpc.Compression.Gzip: "GzipCompression",
}

# Every element of the Cartesian product of these value tuples becomes one
# generated CompressionTest method (see _test_options).
_TEST_OPTIONS = dict(
    client_streaming=(True, False),
    server_streaming=(True, False),
    channel_compression=_COMPRESSION_METHODS,
    multicallable_compression=_COMPRESSION_METHODS,
    server_compression=_COMPRESSION_METHODS,
    server_call_compression=_COMPRESSION_METHODS,
)
65
66
67def _make_handle_unary_unary(pre_response_callback):
68    def _handle_unary(request, servicer_context):
69        if pre_response_callback:
70            pre_response_callback(request, servicer_context)
71        return request
72
73    return _handle_unary
74
75
def _make_handle_unary_stream(pre_response_callback):
    """Return a unary-stream handler that streams its request back.

    The returned generator yields the request ``_STREAM_LENGTH`` times. The
    optional ``pre_response_callback(request, servicer_context)`` runs once,
    when the response stream is first advanced.
    """

    def _repeat_request(request, servicer_context):
        if pre_response_callback:
            pre_response_callback(request, servicer_context)
        yield from (request for _ in range(_STREAM_LENGTH))

    return _repeat_request
84
85
86def _make_handle_stream_unary(pre_response_callback):
87    def _handle_stream_unary(request_iterator, servicer_context):
88        if pre_response_callback:
89            pre_response_callback(request_iterator, servicer_context)
90        response = None
91        for request in request_iterator:
92            if not response:
93                response = request
94        return response
95
96    return _handle_stream_unary
97
98
99def _make_handle_stream_stream(pre_response_callback):
100    def _handle_stream(request_iterator, servicer_context):
101        # TODO(issue:#6891) We should be able to remove this loop,
102        # and replace with return; yield
103        for request in request_iterator:
104            if pre_response_callback:
105                pre_response_callback(request, servicer_context)
106            yield request
107
108    return _handle_stream
109
110
def set_call_compression(
    compression_method, request_or_iterator, servicer_context
):
    """Pre-response callback that sets the compression of the current call.

    Intended to be bound with ``functools.partial(set_call_compression, m)``;
    the request (or request iterator) argument is accepted and ignored.
    """
    servicer_context.set_compression(compression_method)
116
117
def disable_next_compression(request, servicer_context):
    """Pre-response callback that sends the next response uncompressed.

    The request argument is ignored.
    """
    servicer_context.disable_next_message_compression()
121
122
def disable_first_compression(request, servicer_context):
    """Pre-response callback that disables compression for request zero only.

    ``request`` must be ASCII decimal text; only when it parses to 0 is the
    next response sent uncompressed.
    """
    request_number = int(request.decode("ascii"))
    if request_number == 0:
        servicer_context.disable_next_message_compression()
126
127
class _MethodHandler(grpc.RpcMethodHandler):
    """Echoing RpcMethodHandler without serializers.

    Exactly one of the four behaviour attributes is populated, selected by
    the (request_streaming, response_streaming) pair; the others stay None.
    """

    def __init__(
        self, request_streaming, response_streaming, pre_response_callback
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # Messages are passed through as raw bytes: no (de)serialization.
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if request_streaming:
            if response_streaming:
                self.stream_stream = _make_handle_stream_stream(
                    pre_response_callback
                )
            else:
                self.stream_unary = _make_handle_stream_unary(
                    pre_response_callback
                )
        elif response_streaming:
            self.unary_stream = _make_handle_unary_stream(
                pre_response_callback
            )
        else:
            self.unary_unary = _make_handle_unary_unary(pre_response_callback)
151
152
def get_method_handlers(pre_response_callback):
    """Map each method name of the test service to an echoing handler.

    Every handler shares the same optional ``pre_response_callback``.
    """
    method_arities = (
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )
    return {
        method_name: _MethodHandler(
            request_streaming, response_streaming, pre_response_callback
        )
        for method_name, request_streaming, response_streaming in method_arities
    }
160
161
@contextlib.contextmanager
def _instrumented_client_server_pair(
    channel_kwargs, server_kwargs, server_handler
):
    """Start a server, a TCP proxy in front of it, and a channel to the proxy.

    Calls made on the yielded channel go through the proxy, whose byte counts
    callers read via ``proxy.get_byte_count()``.

    Args:
      channel_kwargs: Extra keyword arguments for ``grpc.insecure_channel``.
      server_kwargs: Extra keyword arguments for ``grpc.server``.
      server_handler: Mapping of method name to handler, registered under
        ``_SERVICE_NAME``.

    Yields:
      A ``(client_channel, proxy, server)`` tuple.
    """
    server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
    # Guard everything after the server object exists, so the server is
    # stopped even if registration, binding, start-up, or proxy/channel
    # creation raises (previously only the body of the `with` was guarded).
    try:
        server.add_registered_method_handlers(_SERVICE_NAME, server_handler)
        server_port = server.add_insecure_port("{}:0".format(_HOST))
        server.start()
        with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy:
            proxy_port = proxy.get_port()
            with grpc.insecure_channel(
                "{}:{}".format(_HOST, proxy_port), **channel_kwargs
            ) as client_channel:
                yield client_channel, proxy, server
    finally:
        server.stop(None)
179
180
def _get_byte_counts(
    channel_kwargs,
    multicallable_kwargs,
    client_function,
    server_kwargs,
    server_handler,
    message,
):
    """Run ``client_function`` once over a fresh proxied client/server pair.

    Returns:
      Whatever ``proxy.get_byte_count()`` reports after the call. The counts
      are read before the channel, proxy, and server are torn down.
    """
    with _instrumented_client_server_pair(
        channel_kwargs, server_kwargs, server_handler
    ) as (client_channel, proxy, _):
        client_function(client_channel, multicallable_kwargs, message)
        return proxy.get_byte_count()
195
196
def _get_compression_ratios(
    client_function,
    first_channel_kwargs,
    first_multicallable_kwargs,
    first_server_kwargs,
    first_server_handler,
    second_channel_kwargs,
    second_multicallable_kwargs,
    second_server_kwargs,
    second_server_handler,
    message,
):
    """Compare traffic between two configurations of the same call.

    ``client_function`` runs once against the "first" configuration and once
    against the "second"; each run produces a (sent, received) pair of byte
    counts from the proxy.

    Returns:
      ``(sent_ratio, received_ratio)``, each ``(second - first) / first``.
      Negative values mean the second configuration used fewer bytes.
    """

    def _relative_change(before, after):
        return (after - before) / float(before)

    first_sent, first_received = _get_byte_counts(
        first_channel_kwargs,
        first_multicallable_kwargs,
        client_function,
        first_server_kwargs,
        first_server_handler,
        message,
    )
    second_sent, second_received = _get_byte_counts(
        second_channel_kwargs,
        second_multicallable_kwargs,
        client_function,
        second_server_kwargs,
        second_server_handler,
        message,
    )
    sent_ratio = _relative_change(first_sent, second_sent)
    received_ratio = _relative_change(first_received, second_received)
    return sent_ratio, received_ratio
230
231
def _unary_unary_client(channel, multicallable_kwargs, message):
    """Make one unary-unary call and check that it echoes ``message``.

    Raises:
      RuntimeError: if the response differs from the request.
    """
    method_name = grpc._common.fully_qualified_method(
        _SERVICE_NAME, _UNARY_UNARY
    )
    stub = channel.unary_unary(method_name, _registered_method=True)
    reply = stub(message, **multicallable_kwargs)
    if reply != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, reply)
        )
242
243
def _unary_stream_client(channel, multicallable_kwargs, message):
    """Make one unary-stream call and check every response echoes ``message``.

    Raises:
      RuntimeError: on the first streamed response that differs.
    """
    method_name = grpc._common.fully_qualified_method(
        _SERVICE_NAME, _UNARY_STREAM
    )
    stub = channel.unary_stream(method_name, _registered_method=True)
    for streamed in stub(message, **multicallable_kwargs):
        if streamed != message:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(message, streamed)
            )
255
256
def _stream_unary_client(channel, multicallable_kwargs, message):
    """Stream ``_STREAM_LENGTH`` copies of ``message`` and check the reply.

    The stream-unary handlers in this module reply with the first request
    they receive, so the single response must equal ``message``.

    Raises:
      RuntimeError: if the response differs from ``message``.
    """
    multi_callable = channel.stream_unary(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_UNARY),
        _registered_method=True,
    )
    # Stream the caller's message, not the module constant, so the echo check
    # below compares against what was actually sent.
    requests = (message for _ in range(_STREAM_LENGTH))
    response = multi_callable(requests, **multicallable_kwargs)
    if response != message:
        raise RuntimeError(
            "Request '{}' != Response '{}'".format(message, response)
        )
268
269
def _stream_stream_client(channel, multicallable_kwargs, message):
    """Stream numbered requests and check each response echoes its number.

    ``message`` is not used. Requests are the ASCII decimal numbers
    0.._STREAM_LENGTH-1, each prefixed with 100 ``"0"`` characters. The prefix
    keeps them compressible while they still parse back to their index
    (disable_first_compression relies on request 0 parsing to 0).

    Raises:
      RuntimeError: on the first response whose number does not match.
    """
    multi_callable = channel.stream_stream(
        grpc._common.fully_qualified_method(_SERVICE_NAME, _STREAM_STREAM),
        _registered_method=True,
    )
    padding = b"0" * 100
    requests = (padding + b"%d" % index for index in range(_STREAM_LENGTH))
    responses = multi_callable(requests, **multicallable_kwargs)
    for expected, response in enumerate(responses):
        if int(response.decode("ascii")) != expected:
            raise RuntimeError(
                "Request '{}' != Response '{}'".format(expected, response)
            )
285
286
class CompressionTest(unittest.TestCase):
    """Checks that compression settings measurably change traffic size.

    Each check runs the same RPC twice through a byte-counting TCP proxy:
    once with nothing configured (the baseline) and once with the
    configuration under test. It then compares the relative byte counts
    against _COMPRESSION_RATIO_THRESHOLD.
    """

    def assertCompressed(self, compression_ratio):
        """Assert the ratio is a reduction larger than the threshold."""
        self.assertLess(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertNotCompressed(self, compression_ratio):
        """Assert the ratio is not a reduction beyond the threshold."""
        self.assertGreaterEqual(
            compression_ratio,
            -1.0 * _COMPRESSION_RATIO_THRESHOLD,
            msg="Actual compression ratio: {}".format(compression_ratio),
        )

    def assertConfigurationCompressed(
        self,
        client_streaming,
        server_streaming,
        channel_compression,
        multicallable_compression,
        server_compression,
        server_call_compression,
    ):
        """Run one configuration and assert compression in both directions.

        The first ratio (the proxy's "sent" count) is expected to show
        compression iff a channel or multicallable compression is set. The
        second ratio ("received") is expected to show compression iff a server
        or per-call server compression is set.
        """
        client_side_compressed = bool(
            channel_compression or multicallable_compression
        )
        server_side_compressed = bool(
            server_compression or server_call_compression
        )
        channel_kwargs = (
            {"compression": channel_compression} if channel_compression else {}
        )
        multicallable_kwargs = (
            {"compression": multicallable_compression}
            if multicallable_compression
            else {}
        )
        server_kwargs = (
            {"compression": server_compression} if server_compression else {}
        )

        client_function = {
            (False, False): _unary_unary_client,
            (False, True): _unary_stream_client,
            (True, False): _stream_unary_client,
            (True, True): _stream_stream_client,
        }[(bool(client_streaming), bool(server_streaming))]

        if server_call_compression:
            # Use the configured method (not a hard-coded one) so the check
            # stays correct if more methods are enabled in the matrix.
            server_handler = get_method_handlers(
                functools.partial(
                    set_call_compression, server_call_compression
                )
            )
        else:
            server_handler = get_method_handlers(None)

        sent_ratio, received_ratio = _get_compression_ratios(
            client_function,
            {},
            {},
            {},
            get_method_handlers(None),
            channel_kwargs,
            multicallable_kwargs,
            server_kwargs,
            server_handler,
            _REQUEST,
        )

        # Without these assertions the generated tests could only fail on RPC
        # errors, never on compression behaviour.
        if client_side_compressed:
            self.assertCompressed(sent_ratio)
        else:
            self.assertNotCompressed(sent_ratio)

        if server_side_compressed:
            self.assertCompressed(received_ratio)
        else:
            self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreaming(self):
        # The server compresses by default, but the handler disables
        # compression before every echoed response, so server->client
        # traffic must not shrink.
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            get_method_handlers(None),
            {},
            {},
            server_kwargs,
            get_method_handlers(disable_next_compression),
            _REQUEST,
        )
        self.assertNotCompressed(received_ratio)

    def testDisableNextCompressionStreamingResets(self):
        # Only the response to request 0 is sent uncompressed. The flag must
        # reset so later responses are compressed again.
        server_kwargs = {
            "compression": grpc.Compression.Deflate,
        }
        _, received_ratio = _get_compression_ratios(
            _stream_stream_client,
            {},
            {},
            {},
            get_method_handlers(None),
            {},
            {},
            server_kwargs,
            get_method_handlers(disable_first_compression),
            _REQUEST,
        )
        self.assertCompressed(received_ratio)
400
401
def _get_compression_str(name, value):
    """Return ``name`` followed by the display label of compression ``value``.

    Raises:
      KeyError: if ``value`` has no entry in ``_COMPRESSION_NAMES``.
    """
    label = _COMPRESSION_NAMES[value]
    return f"{name}{label}"
404
405
def _get_compression_test_name(
    client_streaming,
    server_streaming,
    channel_compression,
    multicallable_compression,
    server_compression,
    server_call_compression,
):
    """Build the generated test method name for one configuration.

    For example, client streaming with Gzip on the channel only yields
    ``testStreamUnaryChannelGzipCompressionMulticallableUncompressed``
    ``ServerUncompressedServerCallUncompressed``.
    """
    pieces = [
        "test",
        "Stream" if client_streaming else "Unary",
        "Stream" if server_streaming else "Unary",
    ]
    pieces.extend(
        _get_compression_str(label, method)
        for label, method in (
            ("Channel", channel_compression),
            ("Multicallable", multicallable_compression),
            ("Server", server_compression),
            ("ServerCall", server_call_compression),
        )
    )
    return "".join(pieces)
434
435
def _test_options():
    """Yield one keyword-argument dict per combination in _TEST_OPTIONS."""
    option_names = tuple(_TEST_OPTIONS)
    for combination in itertools.product(*_TEST_OPTIONS.values()):
        yield dict(zip(option_names, combination))
439
440
def _make_configuration_test(configuration):
    """Return a CompressionTest method that checks one configuration.

    Kept private (and outside the loop) so test collectors that pick up
    module-level ``test*`` functions do not mistake the factory for a test.
    """

    def _test_compression(self):
        self.assertConfigurationCompressed(**configuration)

    return _test_compression


# Register one CompressionTest method per configuration in the matrix.
for options in _test_options():
    setattr(
        CompressionTest,
        _get_compression_test_name(**options),
        _make_configuration_test(options),
    )

if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)
458