• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior around the metadata mechanism."""
15
16import asyncio
17import logging
18import platform
19import random
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from tests_aio.unit._test_base import AioTestBase
26from tests_aio.unit import _common
27
# Method path strings for the RPCs in this module.
# NOTE(review): _TEST_ECHO_INITIAL_METADATA has no visible use in this file,
# and _TEST_GENERIC_HANDLER is absent from the routing table of
# _TestGenericHandlerForMethods.
_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer'
_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata'
_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata'
_TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
_TEST_UNARY_STREAM = '/test/TestUnaryStream'
_TEST_STREAM_UNARY = '/test/TestStreamUnary'
_TEST_STREAM_STREAM = '/test/TestStreamStream'

# Fixed request/response payloads; handlers assert on the former and return
# (or yield) the latter.
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'

# Metadata fixtures. The first three each pair one str-valued entry with one
# bytes-valued entry whose key ends in '-bin'.
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
    ('client-to-server', 'question'),
    ('client-to-server-bin', b'\x07\x07\x07'),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
    ('server-to-client', 'answer'),
    ('server-to-client-bin', b'\x06\x06\x06'),
)
_TRAILING_METADATA = aio.Metadata(
    ('a-trailing-metadata', 'stack-trace'),
    ('a-trailing-metadata-bin', b'\x05\x05\x05'),
)
# Checked by _TestGenericHandlerItself.service() against the call details.
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
    ('a-must-have-key', 'secret'),)
54
55_INVALID_METADATA_TEST_CASES = (
56    (
57        TypeError,
58        ((42, 42),),
59    ),
60    (
61        TypeError,
62        (({}, {}),),
63    ),
64    (
65        TypeError,
66        ((None, {}),),
67    ),
68    (
69        TypeError,
70        (({}, {}),),
71    ),
72    (
73        TypeError,
74        (('normal', object()),),
75    ),
76)
77
78
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
    """Generic handler that dispatches by method path to metadata handlers.

    Each handler asserts on what it receives (request payload and/or
    invocation metadata) and, depending on the method, sends initial metadata
    and/or sets trailing metadata before completing.
    """

    def __init__(self):
        # Short aliases keep the table readable; one entry per served method.
        unary_unary = grpc.unary_unary_rpc_method_handler
        unary_stream = grpc.unary_stream_rpc_method_handler
        stream_unary = grpc.stream_unary_rpc_method_handler
        stream_stream = grpc.stream_stream_rpc_method_handler
        self._routing_table = {
            _TEST_CLIENT_TO_SERVER: unary_unary(self._test_client_to_server),
            _TEST_SERVER_TO_CLIENT: unary_unary(self._test_server_to_client),
            _TEST_TRAILING_METADATA: unary_unary(self._test_trailing_metadata),
            _TEST_UNARY_STREAM: unary_stream(self._test_unary_stream),
            _TEST_STREAM_UNARY: stream_unary(self._test_stream_unary),
            _TEST_STREAM_STREAM: stream_stream(self._test_stream_stream),
        }

    @staticmethod
    async def _test_client_to_server(request, context):
        """Assert the client's request and initial metadata, then reply."""
        assert _REQUEST == request
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        return _RESPONSE

    @staticmethod
    async def _test_server_to_client(request, context):
        """Send the server's initial metadata, then reply."""
        assert _REQUEST == request
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
        return _RESPONSE

    @staticmethod
    async def _test_trailing_metadata(request, context):
        """Set trailing metadata, then reply."""
        assert _REQUEST == request
        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_unary_stream(request, context):
        """Check request and metadata, then stream a single response."""
        assert _REQUEST == request
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
        yield _RESPONSE
        # Trailing metadata is set only after the response has been yielded.
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_stream_unary(request_iterator, context):
        """Check metadata, drain the request stream, then reply once."""
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)

        async for incoming in request_iterator:
            assert _REQUEST == incoming

        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_stream_stream(request_iterator, context):
        """Check metadata, drain the request stream, then stream one reply."""
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)

        async for incoming in request_iterator:
            assert _REQUEST == incoming

        yield _RESPONSE
        # Trailing metadata is set only after the response has been yielded.
        context.set_trailing_metadata(_TRAILING_METADATA)

    def service(self, handler_call_details):
        """Return the handler registered for the called method, else None."""
        method = handler_call_details.method
        return self._routing_table.get(method)
158
159
class _TestGenericHandlerItself(grpc.GenericRpcHandler):
    """Generic handler that checks invocation metadata while routing.

    The metadata assertion lives in ``service`` and reads
    ``handler_call_details.invocation_metadata``, so it runs in the generic
    handler rather than in the method handler it returns.
    """

    @staticmethod
    async def _method(request, unused_context):
        """Assert the request payload and return the fixed response."""
        assert _REQUEST == request
        return _RESPONSE

    def service(self, handler_call_details):
        """Assert the expected metadata, then return a unary-unary handler."""
        assert _common.seen_metadata(
            _INITIAL_METADATA_FOR_GENERIC_HANDLER,
            handler_call_details.invocation_metadata,
        )
        method_handler = grpc.unary_unary_rpc_method_handler(self._method)
        return method_handler
171
172
async def _start_test_server():
    """Start an aio server with both test generic handlers registered.

    Returns:
      A tuple ``(address, server)``. ``address`` is ``'localhost:<port>'``,
      where the port is the value returned by
      ``add_insecure_port('[::]:0')``; ``server`` is the started server.
    """
    server = aio.server()
    port = server.add_insecure_port('[::]:0')

    # Registration order matches the tuple order.
    generic_handlers = (
        _TestGenericHandlerForMethods(),
        _TestGenericHandlerItself(),
    )
    server.add_generic_rpc_handlers(generic_handlers)

    await server.start()
    address = 'localhost:%d' % port
    return address, server
182
183
class TestMetadata(AioTestBase):
    """Metadata tests run against a server started anew for every test."""

    async def setUp(self):
        # Keep the server so tearDown can stop it.
        target, server = await _start_test_server()
        self._server = server
        self._client = aio.insecure_channel(target)

    async def tearDown(self):
        # The channel is closed before the server is stopped.
        await self._client.close()
        await self._server.stop(None)

    async def test_from_client_to_server(self):
        """Client initial metadata is accepted and the RPC succeeds."""
        rpc = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = rpc(_REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)

        response = await call
        self.assertEqual(_RESPONSE, response)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_from_server_to_client(self):
        """Server initial metadata equals the expected fixture."""
        rpc = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
        call = rpc(_REQUEST)

        initial_metadata = await call.initial_metadata()
        self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                         initial_metadata)
        response = await call
        self.assertEqual(_RESPONSE, response)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_trailing_metadata(self):
        """Trailing metadata equals the expected fixture."""
        rpc = self._client.unary_unary(_TEST_TRAILING_METADATA)
        call = rpc(_REQUEST)

        trailing_metadata = await call.trailing_metadata()
        self.assertEqual(_TRAILING_METADATA, trailing_metadata)
        response = await call
        self.assertEqual(_RESPONSE, response)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_from_client_to_server_with_list(self):
        """A plain list of pairs works as client metadata."""
        rpc = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        metadata_as_list = list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        call = rpc(_REQUEST, metadata=metadata_as_list)  # pytype: disable=wrong-arg-types

        response = await call
        self.assertEqual(_RESPONSE, response)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    @unittest.skipIf(platform.system() == 'Windows',
                     'https://github.com/grpc/grpc/issues/21943')
    async def test_invalid_metadata(self):
        """Each invalid metadata case raises its expected exception type."""
        rpc = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        for expected_error, bad_metadata in _INVALID_METADATA_TEST_CASES:
            # Creating and awaiting the call both happen inside assertRaises,
            # so the error may surface at either point.
            with self.subTest(metadata=bad_metadata), \
                    self.assertRaises(expected_error):
                call = rpc(_REQUEST, metadata=bad_metadata)
                await call

    async def test_generic_handler(self):
        """Metadata asserted by the generic handler's service() is seen."""
        rpc = self._client.unary_unary(_TEST_GENERIC_HANDLER)
        call = rpc(_REQUEST, metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER)

        response = await call
        self.assertEqual(_RESPONSE, response)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_unary_stream(self):
        """Initial and trailing metadata on a unary-stream call."""
        rpc = self._client.unary_stream(_TEST_UNARY_STREAM)
        call = rpc(_REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)

        initial_metadata = await call.initial_metadata()
        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                                  initial_metadata))

        responses = []
        async for response in call:
            responses.append(response)
        self.assertSequenceEqual([_RESPONSE], responses)

        trailing_metadata = await call.trailing_metadata()
        self.assertEqual(_TRAILING_METADATA, trailing_metadata)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_stream_unary(self):
        """Initial and trailing metadata on a stream-unary call."""
        rpc = self._client.stream_unary(_TEST_STREAM_UNARY)
        call = rpc(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        initial_metadata = await call.initial_metadata()
        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                                  initial_metadata))
        response = await call
        self.assertEqual(_RESPONSE, response)

        trailing_metadata = await call.trailing_metadata()
        self.assertEqual(_TRAILING_METADATA, trailing_metadata)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_stream_stream(self):
        """Initial and trailing metadata on a stream-stream call."""
        rpc = self._client.stream_stream(_TEST_STREAM_STREAM)
        call = rpc(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        initial_metadata = await call.initial_metadata()
        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                                  initial_metadata))

        responses = []
        async for response in call:
            responses.append(response)
        self.assertSequenceEqual([_RESPONSE], responses)

        trailing_metadata = await call.trailing_metadata()
        self.assertEqual(_TRAILING_METADATA, trailing_metadata)
        status = await call.code()
        self.assertEqual(grpc.StatusCode.OK, status)

    async def test_compatibility_with_tuple(self):
        """aio.Metadata compares equal to, and concatenates with, tuples."""
        metadata_obj = aio.Metadata(('key', '42'), ('key-2', 'value'))
        as_tuple = tuple(metadata_obj)

        # Equality is checked with the Metadata object on either side.
        self.assertEqual(metadata_obj, as_tuple)
        self.assertEqual(as_tuple, metadata_obj)

        extra_pair = (('third', '3'),)
        expected_sum = as_tuple + extra_pair
        self.assertEqual(expected_sum, metadata_obj + extra_pair)
        self.assertEqual(expected_sum,
                         metadata_obj + aio.Metadata(('third', '3')))
293
294
# Script entry point: configure DEBUG-level logging, then run this module's
# tests with verbose (verbosity=2) output.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
298