• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests server and client side metadata API."""
15
16import unittest
17import weakref
18import logging
19
20import grpc
21from grpc import _channel
22
23from tests.unit import test_common
24from tests.unit.framework.common import test_constants
25
# Channel options for the client. The server-side check expects the user-agent
# header to start with 'primary-agent ' plus grpc's own agent string and to end
# with 'secondary-agent'.
_CHANNEL_ARGS = (
    ('grpc.primary_user_agent', 'primary-agent'),
    ('grpc.secondary_user_agent', 'secondary-agent'),
)

# Opaque request/response payloads (no serializers are configured for them).
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

# Method paths recognised by _GenericHandler, one per RPC shape.
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'

# Client -> server metadata. Keys and values deliberately mix bytes and text;
# the server is expected to observe them as text, except that the value of the
# '-bin' key stays bytes.
_INVOCATION_METADATA = (
    (b'invocation-md-key', u'invocation-md-value'),
    (u'invocation-md-key-bin', b'\x00\x01'),
)
_EXPECTED_INVOCATION_METADATA = (
    ('invocation-md-key', 'invocation-md-value'),
    ('invocation-md-key-bin', b'\x00\x01'),
)

# Server -> client initial metadata, with the same bytes/text mix and the form
# the client is expected to receive.
_INITIAL_METADATA = (
    (b'initial-md-key', u'initial-md-value'),
    (u'initial-md-key-bin', b'\x00\x02'),
)
_EXPECTED_INITIAL_METADATA = (
    ('initial-md-key', 'initial-md-value'),
    ('initial-md-key-bin', b'\x00\x02'),
)

# Server -> client trailing metadata. It is already in the expected form, so
# the expectation is simply the same object.
_TRAILING_METADATA = (
    ('server-trailing-md-key', 'server-trailing-md-value'),
    ('server-trailing-md-key-bin', b'\x00\x03'),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
82
83
def _user_agent(metadata):
    """Return the value of the first 'user-agent' entry in ``metadata``.

    Args:
      metadata: An iterable of (key, value) pairs.

    Returns:
      The value paired with the first 'user-agent' key.

    Raises:
      KeyError: If no entry has the key 'user-agent'.
    """
    not_found = object()  # sentinel distinct from any real metadata value
    agent = next((value for key, value in metadata if key == 'user-agent'),
                 not_found)
    if agent is not_found:
        raise KeyError('No user agent!')
    return agent
89
90
def validate_client_metadata(test, servicer_context):
    """Assert that the server received the client's metadata and user agent.

    Called from each handle_* function with the context of the RPC being
    served.

    Args:
      test: The unittest.TestCase whose assertion methods are used.
      servicer_context: Server-side context of the in-flight RPC.

    Raises:
      AssertionError: If the expected invocation metadata is not present, or
        the user agent lacks the configured primary/secondary agents.
      KeyError: If the invocation metadata has no 'user-agent' entry.
    """
    invocation_metadata = servicer_context.invocation_metadata()
    # Messages include the actual values; a bare assertTrue would only report
    # "False is not true".
    test.assertTrue(
        test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA,
                                         invocation_metadata),
        'Expected invocation metadata %r not found in %r' %
        (_EXPECTED_INVOCATION_METADATA, invocation_metadata))
    user_agent = _user_agent(invocation_metadata)
    # The primary agent must precede grpc's own agent string, and the
    # secondary agent must come last.
    expected_prefix = 'primary-agent ' + _channel._USER_AGENT
    test.assertTrue(
        user_agent.startswith(expected_prefix),
        'User agent %r does not start with %r' % (user_agent, expected_prefix))
    test.assertTrue(
        user_agent.endswith('secondary-agent'),
        'User agent %r does not end with %r' % (user_agent, 'secondary-agent'))
100
101
def handle_unary_unary(test, request, servicer_context):
    """Serve a unary-unary call.

    Validates the client's metadata, sends the server's initial metadata, sets
    its trailing metadata and returns a single _RESPONSE. The request payload
    is not inspected.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    return _RESPONSE
107
108
def handle_unary_stream(test, request, servicer_context):
    """Serve a unary-stream call.

    Validates the client's metadata, sends the server's initial metadata, sets
    its trailing metadata, then yields _RESPONSE STREAM_LENGTH times. Being a
    generator, none of this runs until the first response is requested.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    stream_length = test_constants.STREAM_LENGTH
    for _ in range(stream_length):
        yield _RESPONSE
115
116
def handle_stream_unary(test, request_iterator, servicer_context):
    """Serve a stream-unary call.

    Validates the client's metadata, sends the server's initial metadata, sets
    its trailing metadata, drains every incoming request and returns a single
    _RESPONSE.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE
125
126
def handle_stream_stream(test, request_iterator, servicer_context):
    """Serve a stream-stream call.

    Validates the client's metadata, sends the server's initial metadata, sets
    its trailing metadata, then yields one _RESPONSE per incoming request.
    Being a generator, none of this runs until the first response is
    requested.
    """
    context = servicer_context
    validate_client_metadata(test, context)
    context.send_initial_metadata(_INITIAL_METADATA)
    context.set_trailing_metadata(_TRAILING_METADATA)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE
135
136
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler that forwards to one of the module's handle_* functions.

    Exactly one of the behaviour attributes (unary_unary, unary_stream,
    stream_unary, stream_stream) is set, chosen by the two streaming flags.
    The others, and both (de)serializers, are left as None.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        # Each behaviour closes over `test` so the handler can make assertions
        # against the running test case.
        if request_streaming:
            if response_streaming:
                self.stream_stream = (
                    lambda requests, context: handle_stream_stream(
                        test, requests, context))
            else:
                self.stream_unary = (
                    lambda requests, context: handle_stream_unary(
                        test, requests, context))
        elif response_streaming:
            self.unary_stream = (
                lambda request, context: handle_unary_stream(
                    test, request, context))
        else:
            self.unary_unary = (
                lambda request, context: handle_unary_unary(
                    test, request, context))
156
157
class _GenericHandler(grpc.GenericRpcHandler):
    """Maps the four /test/* method paths to a matching _MethodHandler.

    Any other method path yields None, i.e. this handler does not serve it.
    """

    # (method path, request_streaming, response_streaming), checked in order.
    _ROUTES = (
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        method = handler_call_details.method
        for path, request_streaming, response_streaming in self._ROUTES:
            if method == path:
                return _MethodHandler(self._test, request_streaming,
                                      response_streaming)
        return None
174
175
class MetadataTest(unittest.TestCase):
    """End-to-end metadata checks for all four RPC shapes.

    The server-side handlers validate what the client sent. Each test here
    validates the initial and trailing metadata the client received back.
    """

    def setUp(self):
        self._server = test_common.test_server()
        # A weak proxy means the server's handler does not hold a strong
        # reference to this test case.
        self._server.add_generic_rpc_handlers(
            (_GenericHandler(weakref.proxy(self)),))
        port = self._server.add_insecure_port('[::]:0')
        self._server.start()
        self._channel = grpc.insecure_channel('localhost:%d' % port,
                                              options=_CHANNEL_ARGS)

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def _assertMetadataTransmitted(self, expected, actual):
        """Assert test_common.metadata_transmitted(expected, actual).

        On failure the message includes both metadata collections, unlike a
        bare assertTrue.
        """
        self.assertTrue(
            test_common.metadata_transmitted(expected, actual),
            'Expected metadata %r not transmitted; got %r' % (expected, actual))

    def testUnaryUnary(self):
        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
        unused_response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA)
        self._assertMetadataTransmitted(_EXPECTED_INITIAL_METADATA,
                                        call.initial_metadata())
        self._assertMetadataTransmitted(_EXPECTED_TRAILING_METADATA,
                                        call.trailing_metadata())

    def testUnaryStream(self):
        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        self._assertMetadataTransmitted(_EXPECTED_INITIAL_METADATA,
                                        call.initial_metadata())
        # Exhaust the response stream before reading trailing metadata.
        for _ in call:
            pass
        self._assertMetadataTransmitted(_EXPECTED_TRAILING_METADATA,
                                        call.trailing_metadata())

    def testStreamUnary(self):
        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
        unused_response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA)
        self._assertMetadataTransmitted(_EXPECTED_INITIAL_METADATA,
                                        call.initial_metadata())
        self._assertMetadataTransmitted(_EXPECTED_TRAILING_METADATA,
                                        call.trailing_metadata())

    def testStreamStream(self):
        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
        call = multi_callable(iter([_REQUEST] * test_constants.STREAM_LENGTH),
                              metadata=_INVOCATION_METADATA)
        self._assertMetadataTransmitted(_EXPECTED_INITIAL_METADATA,
                                        call.initial_metadata())
        # Exhaust the response stream before reading trailing metadata.
        for _ in call:
            pass
        self._assertMetadataTransmitted(_EXPECTED_TRAILING_METADATA,
                                        call.trailing_metadata())
238
239
if __name__ == '__main__':
    # Install a default root logging handler so log output is visible.
    logging.basicConfig()
    unittest.main(verbosity=2)
243