# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior around the metadata mechanism."""

import asyncio
import logging
import platform
import random
import unittest

import grpc
from grpc.experimental import aio

from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase

# Method paths served by _TestGenericHandlerForMethods. _TEST_GENERIC_HANDLER
# is deliberately absent from that routing table so that it falls through to
# _TestGenericHandlerItself.
_TEST_CLIENT_TO_SERVER = "/test/TestClientToServer"
_TEST_SERVER_TO_CLIENT = "/test/TestServerToClient"
_TEST_TRAILING_METADATA = "/test/TestTrailingMetadata"
_TEST_ECHO_INITIAL_METADATA = "/test/TestEchoInitialMetadata"
_TEST_GENERIC_HANDLER = "/test/TestGenericHandler"
_TEST_UNARY_STREAM = "/test/TestUnaryStream"
_TEST_STREAM_UNARY = "/test/TestStreamUnary"
_TEST_STREAM_STREAM = "/test/TestStreamStream"
_TEST_INSPECT_CONTEXT = "/test/TestInspectContext"

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x01\x01\x01"

# Each metadata set mixes a plain ASCII entry with a "-bin" entry carrying
# raw bytes, so both value kinds are exercised in every direction.
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
    ("client-to-server", "question"),
    ("client-to-server-bin", b"\x07\x07\x07"),
)
_INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata(
    ("server-to-client", "answer"),
    ("server-to-client-bin", b"\x06\x06\x06"),
)
_TRAILING_METADATA = aio.Metadata(
    ("a-trailing-metadata", "stack-trace"),
    ("a-trailing-metadata-bin", b"\x05\x05\x05"),
)
_INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata(
    ("a-must-have-key", "secret"),
)

# Each entry pairs the exception type the client is expected to raise with a
# metadata value that must be rejected (consumed by test_invalid_metadata).
_INVALID_METADATA_TEST_CASES = (
    (
        TypeError,
        ((42, 42),),
    ),
    (
        TypeError,
        (({}, {}),),
    ),
    (
        TypeError,
        ((None, {}),),
    ),
    (
        TypeError,
        (("normal", object()),),
    ),
)

_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = "Test details!"


class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
    """Routes each test method path to a handler that checks or emits metadata."""

    def __init__(self):
        self._routing_table = {
            _TEST_CLIENT_TO_SERVER: grpc.unary_unary_rpc_method_handler(
                self._test_client_to_server
            ),
            _TEST_SERVER_TO_CLIENT: grpc.unary_unary_rpc_method_handler(
                self._test_server_to_client
            ),
            _TEST_TRAILING_METADATA: grpc.unary_unary_rpc_method_handler(
                self._test_trailing_metadata
            ),
            _TEST_UNARY_STREAM: grpc.unary_stream_rpc_method_handler(
                self._test_unary_stream
            ),
            _TEST_STREAM_UNARY: grpc.stream_unary_rpc_method_handler(
                self._test_stream_unary
            ),
            _TEST_STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
                self._test_stream_stream
            ),
            _TEST_INSPECT_CONTEXT: grpc.unary_unary_rpc_method_handler(
                self._test_inspect_context
            ),
        }

    @staticmethod
    async def _test_client_to_server(request, context):
        """Asserts the client's initial metadata reached the server."""
        assert _REQUEST == request
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        return _RESPONSE

    @staticmethod
    async def _test_server_to_client(request, context):
        """Sends initial metadata to the client before responding."""
        assert _REQUEST == request
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        )
        return _RESPONSE

    @staticmethod
    async def _test_trailing_metadata(request, context):
        """Attaches trailing metadata to a successful unary response."""
        assert _REQUEST == request
        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_unary_stream(request, context):
        """Checks incoming metadata, then streams one response.

        Initial metadata is sent before the response; trailing metadata is
        set after it.
        """
        assert _REQUEST == request
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        )
        yield _RESPONSE
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_stream_unary(request_iterator, context):
        """Checks incoming metadata, drains all requests, returns one response."""
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        )

        async for request in request_iterator:
            assert _REQUEST == request

        context.set_trailing_metadata(_TRAILING_METADATA)
        return _RESPONSE

    @staticmethod
    async def _test_stream_stream(request_iterator, context):
        """Checks incoming metadata, drains all requests, streams one response."""
        assert _common.seen_metadata(
            _INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
            context.invocation_metadata(),
        )
        await context.send_initial_metadata(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT
        )

        async for request in request_iterator:
            assert _REQUEST == request

        yield _RESPONSE
        context.set_trailing_metadata(_TRAILING_METADATA)

    @staticmethod
    async def _test_inspect_context(request, context):
        """Sets a non-OK status plus trailing metadata and reads them back."""
        assert _REQUEST == request
        context.set_code(_NON_OK_CODE)
        context.set_details(_DETAILS)
        context.set_trailing_metadata(_TRAILING_METADATA)

        # ensure that we can read back the data we set on the context
        assert context.code() == _NON_OK_CODE
        assert context.details() == _DETAILS
        assert context.trailing_metadata() == _TRAILING_METADATA
        return _RESPONSE

    def service(self, handler_call_details):
        """Returns the routed handler, or None for methods not in the table.

        Returning None lets the server fall through to the next registered
        generic handler.
        """
        return self._routing_table.get(handler_call_details.method)


class _TestGenericHandlerItself(grpc.GenericRpcHandler):
    """Catch-all handler that validates metadata inside service() itself."""

    @staticmethod
    async def _method(request, unused_context):
        assert _REQUEST == request
        return _RESPONSE

    def service(self, handler_call_details):
        # The metadata check happens at routing time, before any method
        # handler runs.
        assert _common.seen_metadata(
            _INITIAL_METADATA_FOR_GENERIC_HANDLER,
            handler_call_details.invocation_metadata,
        )
        return grpc.unary_unary_rpc_method_handler(self._method)


async def _start_test_server():
    """Starts an insecure aio server on an ephemeral port.

    Returns:
        A ``(address, server)`` tuple, where address is ``"localhost:<port>"``.
    """
    server = aio.server()
    port = server.add_insecure_port("[::]:0")
    # Order matters: the method router is consulted first, and the catch-all
    # only sees paths the router does not know.
    server.add_generic_rpc_handlers(
        (
            _TestGenericHandlerForMethods(),
            _TestGenericHandlerItself(),
        )
    )
    await server.start()
    return f"localhost:{port}", server


class TestMetadata(AioTestBase):
    """End-to-end metadata tests against a freshly started aio server."""

    async def setUp(self):
        address, self._server = await _start_test_server()
        self._client = aio.insecure_channel(address)

    async def tearDown(self):
        await self._client.close()
        await self._server.stop(None)

    async def test_from_client_to_server(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_server_to_client(self):
        multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT)
        call = multicallable(_REQUEST)

        self.assertEqual(
            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
            await call.initial_metadata(),
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_trailing_metadata(self):
        multicallable = self._client.unary_unary(_TEST_TRAILING_METADATA)
        call = multicallable(_REQUEST)
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_from_client_to_server_with_list(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        call = multicallable(
            _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        )  # pytype: disable=wrong-arg-types
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    @unittest.skipIf(
        platform.system() == "Windows",
        "https://github.com/grpc/grpc/issues/21943",
    )
    async def test_invalid_metadata(self):
        multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
        for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
            with self.subTest(metadata=metadata):
                # The error may surface either when the call is created or
                # when it is awaited; both are inside assertRaises.
                with self.assertRaises(exception_type):
                    call = multicallable(_REQUEST, metadata=metadata)
                    await call

    async def test_generic_handler(self):
        multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER
        )
        self.assertEqual(_RESPONSE, await call)
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_unary_stream(self):
        multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
        call = multicallable(
            _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER
        )

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )

        self.assertSequenceEqual(
            [_RESPONSE], [response async for response in call]
        )

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_unary(self):
        multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )
        self.assertEqual(_RESPONSE, await call)

        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_stream_stream(self):
        multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
        await call.write(_REQUEST)
        await call.done_writing()

        self.assertTrue(
            _common.seen_metadata(
                _INITIAL_METADATA_FROM_SERVER_TO_CLIENT,
                await call.initial_metadata(),
            )
        )
        self.assertSequenceEqual(
            [_RESPONSE], [response async for response in call]
        )
        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
        self.assertEqual(grpc.StatusCode.OK, await call.code())

    async def test_compatibility_with_tuple(self):
        metadata_obj = aio.Metadata(("key", "42"), ("key-2", "value"))
        self.assertEqual(metadata_obj, tuple(metadata_obj))
        self.assertEqual(tuple(metadata_obj), metadata_obj)

        expected_sum = tuple(metadata_obj) + (("third", "3"),)
        self.assertEqual(expected_sum, metadata_obj + (("third", "3"),))
        self.assertEqual(
            expected_sum, metadata_obj + aio.Metadata(("third", "3"))
        )

    async def test_inspect_context(self):
        multicallable = self._client.unary_unary(_TEST_INSPECT_CONTEXT)
        call = multicallable(_REQUEST)
        with self.assertRaises(grpc.RpcError) as exc_data:
            await call

        err = exc_data.exception
        # The server handler returned a value but set a non-OK status; the
        # client must observe the status code and details from the context.
        self.assertEqual(_NON_OK_CODE, err.code())
        self.assertEqual(_DETAILS, err.details())


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)