# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests server and client side metadata API."""

import logging
import unittest
import weakref

import grpc
from grpc import _channel

from tests.unit import test_common
from tests.unit.framework.common import test_constants

# Channel options supplied by the client; validate_client_metadata() checks
# that both agent strings show up in the user-agent the server receives.
_CHANNEL_ARGS = (
    ("grpc.primary_user_agent", "primary-agent"),
    ("grpc.secondary_user_agent", "secondary-agent"),
)

_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Metadata sent by the client. Note that the first key is given as bytes
# while the matching _EXPECTED_* constant spells it as str.
_INVOCATION_METADATA = (
    (b"invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)
_EXPECTED_INVOCATION_METADATA = (
    ("invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)

# Initial metadata sent by the server (first key again given as bytes).
_INITIAL_METADATA = (
    (b"initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)
_EXPECTED_INITIAL_METADATA = (
    ("initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)

# Trailing metadata set by the server; expected back on the client unchanged.
_TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", b"\x00\x03"),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA


def _user_agent(metadata):
    """Return the value of the first "user-agent" entry in ``metadata``.

    Args:
      metadata: An iterable of (key, value) pairs.

    Raises:
      KeyError: If no entry has the key "user-agent".
    """
    not_found = object()
    agent = next(
        (value for name, value in metadata if name == "user-agent"),
        not_found,
    )
    if agent is not_found:
        raise KeyError("No user agent!")
    return agent


def validate_client_metadata(test, servicer_context):
    """Assert on ``test`` that the client's metadata reached the server.

    Checks that the expected invocation metadata was transmitted and that the
    received user-agent starts with the primary agent followed by the
    library's own user agent string and ends with the secondary agent.
    """
    received = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(
            _EXPECTED_INVOCATION_METADATA, received
        )
    )
    agent = _user_agent(received)
    expected_prefix = "primary-agent " + _channel._USER_AGENT
    test.assertTrue(agent.startswith(expected_prefix))
    test.assertTrue(agent.endswith("secondary-agent"))


def _exchange_metadata(test, servicer_context):
    """Validate the client's metadata, then send the server's metadata.

    These are the steps every handler below performs before producing its
    response(s).
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)


def handle_unary_unary(test, request, servicer_context):
    """Unary-unary behavior: exchange metadata and return one response."""
    _exchange_metadata(test, servicer_context)
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Unary-stream behavior: exchange metadata, then yield
    test_constants.STREAM_LENGTH responses."""
    _exchange_metadata(test, servicer_context)
    yield from [_RESPONSE] * test_constants.STREAM_LENGTH


def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary behavior: exchange metadata, consume every request and
    return one response."""
    _exchange_metadata(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream behavior: exchange metadata, then yield one response
    per request received."""
    _exchange_metadata(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler exposing exactly one of the four handle_* behaviors.

    The behavior is selected by the truthiness of ``request_streaming`` and
    ``response_streaming``; the other three behavior attributes are None.
    No request deserializer or response serializer is configured.
    """

    # (request streams?, response streams?) -> (attribute to populate,
    # module-level behavior that attribute should invoke).
    _BEHAVIORS = {
        (False, False): ("unary_unary", handle_unary_unary),
        (False, True): ("unary_stream", handle_unary_stream),
        (True, False): ("stream_unary", handle_stream_unary),
        (True, True): ("stream_stream", handle_stream_stream),
    }

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None

        # Start with every behavior slot empty, then fill in the one that
        # matches the requested streaming mode.
        for slot, _ in self._BEHAVIORS.values():
            setattr(self, slot, None)

        mode = (bool(request_streaming), bool(response_streaming))
        chosen_slot, behavior = self._BEHAVIORS[mode]

        def invoke(request_or_iterator, servicer_context):
            # Supplies the test case as the behavior's first argument.
            return behavior(test, request_or_iterator, servicer_context)

        setattr(self, chosen_slot, invoke)


def get_method_handlers(test):
    """Return a dict mapping each method name to its _MethodHandler.

    Args:
      test: The object the handlers run their assertions against.
    """
    streaming_modes = (
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )
    return {
        method: _MethodHandler(test, streams_requests, streams_responses)
        for method, streams_requests, streams_responses in streaming_modes
    }


class MetadataTest(unittest.TestCase):
    """Round-trips metadata between a client and an in-process test server
    for each of the four RPC arities."""

    def setUp(self):
        server = test_common.test_server()
        # The handlers get a weak proxy to this test case (presumably so the
        # server does not keep the test case alive -- confirm).
        server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(weakref.proxy(self))
        )
        bound_port = server.add_insecure_port("[::]:0")
        server.start()
        self._server = server
        self._channel = grpc.insecure_channel(
            f"localhost:{bound_port:d}", options=_CHANNEL_ARGS
        )

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    @staticmethod
    def _method_path(method_name):
        """Return the fully qualified name of ``method_name`` on the test
        service."""
        return grpc._common.fully_qualified_method(_SERVICE_NAME, method_name)

    def _assert_transmitted(self, expected, actual):
        """Assert that test_common.metadata_transmitted accepts the pair."""
        self.assertTrue(test_common.metadata_transmitted(expected, actual))

    def testUnaryUnary(self):
        invoke = self._channel.unary_unary(
            self._method_path(_UNARY_UNARY),
            _registered_method=True,
        )
        _, call = invoke.with_call(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testUnaryStream(self):
        invoke = self._channel.unary_stream(
            self._method_path(_UNARY_STREAM),
            _registered_method=True,
        )
        call = invoke(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assert_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Consume every response before asking for the trailing metadata.
        for _ in call:
            pass
        self._assert_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamUnary(self):
        invoke = self._channel.stream_unary(
            self._method_path(_STREAM_UNARY),
            _registered_method=True,
        )
        requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        _, call = invoke.with_call(requests, metadata=_INVOCATION_METADATA)
        self._assert_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamStream(self):
        invoke = self._channel.stream_stream(
            self._method_path(_STREAM_STREAM),
            _registered_method=True,
        )
        requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
        call = invoke(requests, metadata=_INVOCATION_METADATA)
        self._assert_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # Consume every response before asking for the trailing metadata.
        for _ in call:
            pass
        self._assert_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)