• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests behavior around the metadata mechanism."""
15
16import asyncio
17import logging
18import platform
19import random
20import unittest
21
22import grpc
23from grpc.experimental import aio
24
25from tests_aio.unit import _common
26from tests_aio.unit._test_base import AioTestBase
27
# Method paths used by the tests in this module. Most of them are keys in
# the routing table of _TestGenericHandlerForMethods.
_TEST_CLIENT_TO_SERVER = "/test/TestClientToServer"
_TEST_SERVER_TO_CLIENT = "/test/TestServerToClient"
_TEST_TRAILING_METADATA = "/test/TestTrailingMetadata"
# NOTE(review): not referenced anywhere else in the visible part of this file.
_TEST_ECHO_INITIAL_METADATA = "/test/TestEchoInitialMetadata"
# Has no routing-table entry; used by test_generic_handler.
_TEST_GENERIC_HANDLER = "/test/TestGenericHandler"
_TEST_UNARY_STREAM = "/test/TestUnaryStream"
_TEST_STREAM_UNARY = "/test/TestStreamUnary"
_TEST_STREAM_STREAM = "/test/TestStreamStream"
_TEST_INSPECT_CONTEXT = "/test/TestInspectContext"

# Byte payloads: handlers assert they receive _REQUEST and reply with
# _RESPONSE.
_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x01\x01\x01"
40
# Metadata fixtures. The first three pair a text entry with a bytes entry
# stored under a key ending in "-bin".

# Attached by the client to outgoing calls; handlers assert it was received.
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
    ("client-to-server", "question"),
    ("client-to-server-bin", b"\x07\x07\x07"),
)
# Sent by handlers through context.send_initial_metadata().
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
    ("server-to-client", "answer"),
    ("server-to-client-bin", b"\x06\x06\x06"),
)
# Set by handlers through context.set_trailing_metadata().
_TRAILING_METADATA = aio.Metadata(
    ("a-trailing-metadata", "stack-trace"),
    ("a-trailing-metadata-bin", b"\x05\x05\x05"),
)
# Asserted by _TestGenericHandlerItself.service() on the call details.
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
    ("a-must-have-key", "secret"),
)
56
57_INVALID_METADATA_TEST_CASES = (
58    (
59        TypeError,
60        ((42, 42),),
61    ),
62    (
63        TypeError,
64        (({}, {}),),
65    ),
66    (
67        TypeError,
68        ((None, {}),),
69    ),
70    (
71        TypeError,
72        (({}, {}),),
73    ),
74    (
75        TypeError,
76        (("normal", object()),),
77    ),
78)
79
# Status code and details that the inspect-context handler sets on its
# context and then reads back.
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = "Test details!"
82
83
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
    """Generic handler that maps each test method path to a coroutine.

    ``service`` looks the requested method up in a routing table and returns
    ``None`` for paths that are not registered there.
    """

    def __init__(self):
        # Short aliases for the four handler factories keep the table compact.
        unary_unary = grpc.unary_unary_rpc_method_handler
        unary_stream = grpc.unary_stream_rpc_method_handler
        stream_unary = grpc.stream_unary_rpc_method_handler
        stream_stream = grpc.stream_stream_rpc_method_handler
        self._routing_table = {
            _TEST_CLIENT_TO_SERVER: unary_unary(self._test_client_to_server),
            _TEST_SERVER_TO_CLIENT: unary_unary(self._test_server_to_client),
            _TEST_TRAILING_METADATA: unary_unary(self._test_trailing_metadata),
            _TEST_UNARY_STREAM: unary_stream(self._test_unary_stream),
            _TEST_STREAM_UNARY: stream_unary(self._test_stream_unary),
            _TEST_STREAM_STREAM: stream_stream(self._test_stream_stream),
            _TEST_INSPECT_CONTEXT: unary_unary(self._test_inspect_context),
        }

    @staticmethod
    async def _test_client_to_server(request, context):
        """Assert the request and the client's metadata, then respond."""
        assert request == _REQUEST
        received = context.invocation_metadata()
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, received
        )
        return _RESPONSE

    @staticmethod
    async def _test_server_to_client(request, context):
        """Send the server's initial metadata before responding."""
        assert request == _REQUEST
        outgoing = _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        await context.send_initial_metadata(outgoing)
        return _RESPONSE

    @staticmethod
    async def _test_trailing_metadata(request, context):
        """Set trailing metadata on the context and respond."""
        assert request == _REQUEST
        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_unary_stream(request, context):
        """Check inbound metadata, send initial metadata, yield one response.

        Trailing metadata is set only after the response has been yielded.
        """
        assert request == _REQUEST
        received = context.invocation_metadata()
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, received
        )
        outgoing = _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        await context.send_initial_metadata(outgoing)
        yield _RESPONSE
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_stream_unary(request_iterator, context):
        """Check inbound metadata, drain the request stream, then respond."""
        received = context.invocation_metadata()
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, received
        )
        outgoing = _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        await context.send_initial_metadata(outgoing)

        # Every message the client streams must be the expected payload.
        async for message in request_iterator:
            assert message == _REQUEST

        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_stream_stream(request_iterator, context):
        """Check inbound metadata, drain requests, then yield one response.

        Trailing metadata is set only after the response has been yielded.
        """
        received = context.invocation_metadata()
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, received
        )
        outgoing = _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        await context.send_initial_metadata(outgoing)

        # Every message the client streams must be the expected payload.
        async for message in request_iterator:
            assert message == _REQUEST

        yield _RESPONSE
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_inspect_context(request, context):
        """Set code, details and trailing metadata, then read each back."""
        assert request == _REQUEST
        context.set_code(_NON_OK_CODE)
        context.set_details(_DETAILS)
        context.set_trailing_metadata(_TRAILING_METADATA)

        # Each value written above must be readable back from the context.
        code = context.code()
        assert code == _NON_OK_CODE
        details = context.details()
        assert details == _DETAILS
        trailing = context.trailing_metadata()
        assert trailing == _TRAILING_METADATA
        return _RESPONSE

    def service(self, handler_call_details):
        """Return the handler registered for the called method, or None."""
        method = handler_call_details.method
        return self._routing_table.get(method)
193
194
class _TestGenericHandlerItself(grpc.GenericRpcHandler):
    """Generic handler that inspects metadata inside ``service`` itself.

    It returns the same unary-unary handler for every method it is asked
    about, after asserting on the invocation metadata in the call details.
    """

    @staticmethod
    async def _method(request, unused_context):
        """Assert the expected request payload and return the response."""
        assert request == _REQUEST
        return _RESPONSE

    def service(self, handler_call_details):
        """Assert the generic-handler metadata is present, then serve."""
        received = handler_call_details.invocation_metadata
        assert _common.seen_metadata(
            _INITIAL_METADATA_FOR_GENERIC_HANDLER, received
        )
        handler = grpc.unary_unary_rpc_method_handler(self._method)
        return handler
207
208
async def _start_test_server():
    """Start an insecure aio server that has both generic handlers.

    Returns:
        A ``(address, server)`` tuple, where ``address`` is
        ``"localhost:<port>"`` for the port returned by
        ``add_insecure_port`` and ``server`` is the started server.
    """
    server = aio.server()
    port = server.add_insecure_port("[::]:0")
    # The order of the handlers in this tuple is kept as in the original.
    handlers = (
        _TestGenericHandlerForMethods(),
        _TestGenericHandlerItself(),
    )
    server.add_generic_rpc_handlers(handlers)
    await server.start()
    address = f"localhost:{port:d}"
    return address, server
220
221
class TestMetadata(AioTestBase):
    """Client-side tests for metadata sent and received over aio calls.

    Each test gets its own server (see ``_start_test_server``) and its own
    insecure channel to it.
    """

    async def setUp(self):
        """Start the test server and open an insecure channel to it."""
        address, self._server = await _start_test_server()
        self._client = aio.insecure_channel(address)

    async def tearDown(self):
        """Close the channel, then stop the server (grace=None)."""
        await self._client.close()
        await self._server.stop(None)

    async def test_from_client_to_server(self):
        """Client metadata passed as aio.Metadata is accepted by the server."""
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_server_to_client(self):
        """Initial metadata sent by the handler is visible on the call."""
        multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
        call = multicallable(_REQUEST)

        self.assertEqual(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
            await call.initial_metadata(),
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_trailing_metadata(self):
        """Trailing metadata set by the handler is visible on the call."""
        multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA)
        call = multicallable(_REQUEST)
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_client_to_server_with_list(self):
        """Client metadata passed as a plain list is accepted as well."""
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(
            _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        )  # pytype: disable=wrong-arg-types
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    @unittest.skipIf(
        platform.system() == "Windows",
        "https://github.com/grpc/grpc/issues/21943",
    )
    async def test_invalid_metadata(self):
        """Each malformed metadata case raises its expected exception type."""
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
            with self.subTest(metadata=metadata):
                # Both creating and awaiting the call sit inside the
                # context manager, so the error may surface at either step.
                with self.assertRaises(exception_type):
                    call = multicallable(_REQUEST, metadata=metadata)
                    await call

    async def test_generic_handler(self):
        """A path outside the routing table is served by the second handler."""
        multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_unary_stream(self):
        """Initial and trailing metadata work on a unary-stream call."""
        multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER
        )

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )

        # Iterating the call yields the server's responses.
        self.assertSequenceEqual(
            [_RESPONSE], [response async for response in call]
        )

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_unary(self):
        """Initial and trailing metadata work on a stream-unary call."""
        multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )
        self.assertEqual(_RESPONSE, await call)

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream(self):
        """Initial and trailing metadata work on a stream-stream call."""
        multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )
        # Iterating the call yields the server's responses.
        self.assertSequenceEqual(
            [_RESPONSE], [response async for response in call]
        )
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_compatibility_with_tuple(self):
        """aio.Metadata compares and concatenates like a tuple of pairs."""
        metadata_obj = aio.Metadata(("key", "42"), ("key-2", "value"))
        # Equality is checked in both operand orders.
        self.assertEqual(metadata_obj, tuple(metadata_obj))
        self.assertEqual(tuple(metadata_obj), metadata_obj)

        expected_sum = tuple(metadata_obj) + (("third", "3"),)
        self.assertEqual(expected_sum, metadata_obj + (("third", "3"),))
        self.assertEqual(
            expected_sum, metadata_obj + aio.Metadata(("third", "3"))
        )

    async def test_inspect_context(self):
        """Code, details and trailing metadata set server-side reach the client.

        The handler sets a non-OK code, so awaiting the call raises; the
        raised error is then checked for everything the handler set.
        """
        multicallable = self._client.unary_unary(_TEST_INSPECT_CONTEXT)
        call = multicallable(_REQUEST)
        with self.assertRaises(grpc.RpcError) as exc_data:
            await call

        err = exc_data.exception
        self.assertEqual(_NON_OK_CODE, err.code())
        self.assertEqual(_DETAILS, err.details())
        # Subset check, consistent with the initial-metadata checks above.
        self.assertTrue(
            _common.seen_metadata(_TRAILING_METADATA, err.trailing_metadata())
        )
360
if __name__ == "__main__":
    # When executed directly: log at DEBUG level and run the tests verbosely.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)
364