# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the metadata mechanism."""

import asyncio
import logging
import platform
import random
import unittest

import grpc
from grpc.experimental import aio

from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common

_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer'
_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata'
_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata'
_TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
_TEST_UNARY_STREAM = '/test/TestUnaryStream'
_TEST_STREAM_UNARY = '/test/TestStreamUnary'
_TEST_STREAM_STREAM = '/test/TestStreamStream'

_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'

_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
    ('client-to-server', 'question'),
    ('client-to-server-bin', b'\x07\x07\x07'),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
    ('server-to-client', 'answer'),
    ('server-to-client-bin', b'\x06\x06\x06'),
)
_TRAILING_METADATA = aio.Metadata(
    ('a-trailing-metadata', 'stack-trace'),
    ('a-trailing-metadata-bin', b'\x05\x05\x05'),
)
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
    ('a-must-have-key', 'secret'),)

# Each entry is (expected exception type, metadata passed to the client
# call). The `(({}, {}),)` case previously appeared twice; the redundant
# copy was dropped since it exercised nothing new.
_INVALID_METADATA_TEST_CASES = (
    (
        TypeError,
        ((42, 42),),
    ),
    (
        TypeError,
        (({}, {}),),
    ),
    (
        TypeError,
        ((None, {}),),
    ),
    (
        TypeError,
        (('normal', object()),),
    ),
)


class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
    """Routes each per-method test path to a handler exercising metadata.

    The handlers assert on the invocation metadata they receive from the
    client and/or send initial and trailing metadata back to it.
    """

    def __init__(self):
        self._routing_table = {
            _TEST_CLIENT_TO_SERVER:
                grpc.unary_unary_rpc_method_handler(self._test_client_to_server
                                                   ),
            _TEST_SERVER_TO_CLIENT:
                grpc.unary_unary_rpc_method_handler(self._test_server_to_client
                                                   ),
            _TEST_TRAILING_METADATA:
                grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata
                                                   ),
            _TEST_UNARY_STREAM:
                grpc.unary_stream_rpc_method_handler(self._test_unary_stream),
            _TEST_STREAM_UNARY:
                grpc.stream_unary_rpc_method_handler(self._test_stream_unary),
            _TEST_STREAM_STREAM:
                grpc.stream_stream_rpc_method_handler(self._test_stream_stream),
        }

    @staticmethod
    async def _test_client_to_server(request, context):
        """Check the client's initial metadata reached the server."""
        assert _REQUEST == request
        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
                                     context.invocation_metadata())
        return _RESPONSE

    @staticmethod
    async def _test_server_to_client(request, context):
        """Send initial metadata to the client before responding."""
        assert _REQUEST == request
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
        return _RESPONSE

    @staticmethod
    async def _test_trailing_metadata(request, context):
        """Attach trailing metadata to a unary response."""
        assert _REQUEST == request
        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_unary_stream(request, context):
        """Check incoming metadata, then send initial and trailing metadata."""
        assert _REQUEST == request
        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
                                     context.invocation_metadata())
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
        yield _RESPONSE
        # Trailing metadata is set only after the single response is yielded.
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_stream_unary(request_iterator, context):
        """Check incoming metadata, drain requests, reply with trailers."""
        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
                                     context.invocation_metadata())
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)

        async for request in request_iterator:
            assert _REQUEST == request

        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_stream_stream(request_iterator, context):
        """Check incoming metadata, echo one response per request."""
        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
                                     context.invocation_metadata())
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)

        async for request in request_iterator:
            assert _REQUEST == request

            yield _RESPONSE
        context.set_trailing_metadata(_TRAILING_METADATA)

    def service(self, handler_call_details):
        """Return the method handler for a known test path, else None.

        Per the grpc.GenericRpcHandler contract, returning None lets another
        registered generic handler service the call.
        """
        return self._routing_table.get(handler_call_details.method)


class _TestGenericHandlerItself(grpc.GenericRpcHandler):
    """Generic handler that checks metadata on handler_call_details itself.

    It is registered after _TestGenericHandlerForMethods in
    _start_test_server, so it receives methods that handler does not route
    (such as _TEST_GENERIC_HANDLER).
    """

    @staticmethod
    async def _method(request, unused_context):
        assert _REQUEST == request
        return _RESPONSE

    def service(self, handler_call_details):
        assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
                                     handler_call_details.invocation_metadata)
        return grpc.unary_unary_rpc_method_handler(self._method)


async def _start_test_server():
    """Start an aio server with both test handlers on port 0.

    Returns:
        A tuple of ('localhost:<bound port>', the started server).
    """
    server = aio.server()
    port = server.add_insecure_port('[::]:0')
    server.add_generic_rpc_handlers((
        _TestGenericHandlerForMethods(),
        _TestGenericHandlerItself(),
    ))
    await server.start()
    return 'localhost:%d' % port, server


class TestMetadata(AioTestBase):
    """End-to-end tests of initial/trailing metadata over an aio channel."""

    async def setUp(self):
        address, self._server = await _start_test_server()
        self._client = aio.insecure_channel(address)

    async def tearDown(self):
        # Stop the server even if closing the channel fails, so a broken
        # teardown does not leak a running server into later tests.
        try:
            await self._client.close()
        finally:
            await self._server.stop(None)

    async def test_from_client_to_server(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(_REQUEST,
                             metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_server_to_client(self):
        multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
        call = multicallable(_REQUEST)

        self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
                         call.initial_metadata())
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_trailing_metadata(self):
        multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA)
        call = multicallable(_REQUEST)
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_client_to_server_with_list(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(
            _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER))  # pytype: disable=wrong-arg-types
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    @unittest.skipIf(platform.system() == 'Windows',
                     'https://github.com/grpc/grpc/issues/21943')
    async def test_invalid_metadata(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
            with self.subTest(metadata=metadata):
                # The error may surface when the call is created or when it
                # is awaited, so both happen inside assertRaises.
                with self.assertRaises(exception_type):
                    call = multicallable(_REQUEST, metadata=metadata)
                    await call

    async def test_generic_handler(self):
        multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
        call = multicallable(_REQUEST,
                             metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER)
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_unary_stream(self):
        multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
        call = multicallable(_REQUEST,
                             metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)

        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
                                  call.initial_metadata()))

        self.assertSequenceEqual([_RESPONSE],
                                 [response async for response in call])

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_unary(self):
        multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
                                  call.initial_metadata()))
        self.assertEqual(_RESPONSE, await call)

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream(self):
        multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
                                  call.initial_metadata()))
        self.assertSequenceEqual([_RESPONSE],
                                 [response async for response in call])
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_compatibility_with_tuple(self):
        metadata_obj = aio.Metadata(('key', '42'), ('key-2', 'value'))
        self.assertEqual(metadata_obj, tuple(metadata_obj))
        self.assertEqual(tuple(metadata_obj), metadata_obj)

        expected_sum = tuple(metadata_obj) + (('third', '3'),)
        self.assertEqual(expected_sum, metadata_obj + (('third', '3'),))
        self.assertEqual(expected_sum, metadata_obj + aio.Metadata(
            ('third', '3')))


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)