• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests server and client side metadata API."""
15
16import logging
17import unittest
18import weakref
19
20import grpc
21from grpc import _channel
22
23from tests.unit import test_common
24from tests.unit.framework.common import test_constants
25
# Channel options for the client; the server expects the resulting user-agent
# to start with "primary-agent " and end with "secondary-agent"
# (see validate_client_metadata).
_CHANNEL_ARGS = (
    ("grpc.primary_user_agent", "primary-agent"),
    ("grpc.secondary_user_agent", "secondary-agent"),
)

# Fixed payloads used by every RPC in this file.
_REQUEST = b"\x00\x00\x00"
_RESPONSE = b"\x00\x00\x00"

# Service and method names under which the test handlers are registered.
_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Client -> server metadata. The plain key is given as bytes here but as str
# in the expected tuple; the "-bin" key carries a bytes value.
_INVOCATION_METADATA = (
    (b"invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)
_EXPECTED_INVOCATION_METADATA = (
    ("invocation-md-key", "invocation-md-value"),
    ("invocation-md-key-bin", b"\x00\x01"),
)

# Server -> client initial metadata, with the same bytes-key / str-key pairing.
_INITIAL_METADATA = (
    (b"initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)
_EXPECTED_INITIAL_METADATA = (
    ("initial-md-key", "initial-md-value"),
    ("initial-md-key-bin", b"\x00\x02"),
)

# Server -> client trailing metadata. Its keys are already str, so the
# expected value is the very same tuple object.
_TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", b"\x00\x03"),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
87
88
def _user_agent(metadata):
    """Return the value of the first ``user-agent`` entry in ``metadata``.

    Args:
      metadata: an iterable of (key, value) pairs.

    Raises:
      KeyError: if no entry has the key ``"user-agent"``.
    """
    missing = object()
    found = next(
        (value for name, value in metadata if name == "user-agent"), missing
    )
    if found is missing:
        raise KeyError("No user agent!")
    return found
94
95
def validate_client_metadata(test, servicer_context):
    """Assert that the server received the client's invocation metadata.

    Checks that the invocation metadata contains
    _EXPECTED_INVOCATION_METADATA (per test_common.metadata_transmitted) and
    that the user-agent entry starts with ``"primary-agent " +
    _channel._USER_AGENT`` and ends with ``"secondary-agent"``, matching
    _CHANNEL_ARGS. Each failure message shows what was actually received.

    Args:
      test: the TestCase (in this file, a weakref proxy to it) whose
        assertion methods are used.
      servicer_context: the servicer context of the RPC being handled.

    Raises:
      AssertionError: via ``test`` when any check fails.
      KeyError: if the invocation metadata has no user-agent entry.
    """
    invocation_metadata = servicer_context.invocation_metadata()
    test.assertTrue(
        test_common.metadata_transmitted(
            _EXPECTED_INVOCATION_METADATA, invocation_metadata
        ),
        "invocation metadata %r not found in %r"
        % (_EXPECTED_INVOCATION_METADATA, invocation_metadata),
    )
    user_agent = _user_agent(invocation_metadata)
    expected_prefix = "primary-agent " + _channel._USER_AGENT
    expected_suffix = "secondary-agent"
    test.assertTrue(
        user_agent.startswith(expected_prefix),
        "user agent %r does not start with %r" % (user_agent, expected_prefix),
    )
    test.assertTrue(
        user_agent.endswith(expected_suffix),
        "user agent %r does not end with %r" % (user_agent, expected_suffix),
    )
108
109
def _prepare_server_call(test, servicer_context):
    """Shared prologue for every handle_* function below.

    Validates the client's invocation metadata, sends _INITIAL_METADATA and
    sets _TRAILING_METADATA on ``servicer_context``.
    """
    validate_client_metadata(test, servicer_context)
    servicer_context.send_initial_metadata(_INITIAL_METADATA)
    servicer_context.set_trailing_metadata(_TRAILING_METADATA)


def handle_unary_unary(test, request, servicer_context):
    """Unary-unary behavior: run the metadata prologue, return _RESPONSE."""
    _prepare_server_call(test, servicer_context)
    return _RESPONSE


def handle_unary_stream(test, request, servicer_context):
    """Unary-stream behavior: yield _RESPONSE STREAM_LENGTH times.

    This is a generator, so the metadata prologue runs when the first
    response is requested.
    """
    _prepare_server_call(test, servicer_context)
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(test, request_iterator, servicer_context):
    """Stream-unary behavior: drain all requests, then return _RESPONSE."""
    _prepare_server_call(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(test, request_iterator, servicer_context):
    """Stream-stream behavior: yield one _RESPONSE per received request."""
    _prepare_server_call(test, servicer_context)
    # TODO(issue:#6891) We should be able to remove this loop,
    # and replace with return; yield
    for _ in request_iterator:
        yield _RESPONSE
143
144
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler that routes one RPC arity to its handle_* function.

    Exactly one of ``unary_unary`` / ``unary_stream`` / ``stream_unary`` /
    ``stream_stream`` is set, selected by ``request_streaming`` and
    ``response_streaming``; the other three stay None. No request
    deserializer or response serializer is configured.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = self.response_serializer = None
        self.unary_unary = self.unary_stream = None
        self.stream_unary = self.stream_stream = None

        def bind(handler):
            # Close over ``test`` so the behavior takes only
            # (request_or_iterator, servicer_context).
            return lambda x, y: handler(test, x, y)

        if not request_streaming:
            if response_streaming:
                self.unary_stream = bind(handle_unary_stream)
            else:
                self.unary_unary = bind(handle_unary_unary)
        elif response_streaming:
            self.stream_stream = bind(handle_stream_stream)
        else:
            self.stream_unary = bind(handle_stream_unary)
163
164
def get_method_handlers(test):
    """Map each test method name to a _MethodHandler of matching arity.

    Args:
      test: passed through to every handler (the TestCase, or a proxy to it).

    Returns:
      A dict keyed by _UNARY_UNARY, _UNARY_STREAM, _STREAM_UNARY and
      _STREAM_STREAM, in that order.
    """
    arities = (
        # (method name, request_streaming, response_streaming)
        (_UNARY_UNARY, False, False),
        (_UNARY_STREAM, False, True),
        (_STREAM_UNARY, True, False),
        (_STREAM_STREAM, True, True),
    )
    return {
        name: _MethodHandler(test, request_streaming, response_streaming)
        for name, request_streaming, response_streaming in arities
    }
172
173
class MetadataTest(unittest.TestCase):
    """Round-trips metadata over a real channel for each RPC arity.

    Invocation metadata and the user-agent are checked server-side by the
    handle_* functions. Initial and trailing metadata, and the response
    payloads, are checked here on the client side.
    """

    def setUp(self):
        self._server = test_common.test_server()
        # Handlers get only a weak proxy to the test case, so the server's
        # handler table does not keep this TestCase alive.
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, get_method_handlers(weakref.proxy(self))
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel(
            "localhost:%d" % port, options=_CHANNEL_ARGS
        )

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def _multi_callable(self, factory, method):
        """Create a registered multi-callable for ``method`` on the service.

        Args:
          factory: one of the channel's ``unary_unary`` / ``unary_stream`` /
            ``stream_unary`` / ``stream_stream`` bound methods.
          method: the bare method name, e.g. _UNARY_UNARY.
        """
        return factory(
            grpc._common.fully_qualified_method(_SERVICE_NAME, method),
            _registered_method=True,
        )

    def _assert_metadata_transmitted(self, expected, actual):
        """Assert ``expected`` metadata is present in ``actual``.

        Wraps test_common.metadata_transmitted so that a failure reports
        both sides instead of only "False is not true".
        """
        self.assertTrue(
            test_common.metadata_transmitted(expected, actual),
            "expected metadata %r not found in %r" % (expected, actual),
        )

    def testUnaryUnary(self):
        multi_callable = self._multi_callable(
            self._channel.unary_unary, _UNARY_UNARY
        )
        response, call = multi_callable.with_call(
            _REQUEST, metadata=_INVOCATION_METADATA
        )
        self.assertEqual(_RESPONSE, response)
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testUnaryStream(self):
        multi_callable = self._multi_callable(
            self._channel.unary_stream, _UNARY_STREAM
        )
        call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
        # Initial metadata is checked before any response is consumed.
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        responses = list(call)
        self.assertEqual([_RESPONSE] * test_constants.STREAM_LENGTH, responses)
        # Trailing metadata is checked only after the stream is drained.
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamUnary(self):
        multi_callable = self._multi_callable(
            self._channel.stream_unary, _STREAM_UNARY
        )
        response, call = multi_callable.with_call(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        self.assertEqual(_RESPONSE, response)
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )

    def testStreamStream(self):
        multi_callable = self._multi_callable(
            self._channel.stream_stream, _STREAM_STREAM
        )
        call = multi_callable(
            iter([_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_INVOCATION_METADATA,
        )
        # Initial metadata is checked before any response is consumed.
        self._assert_metadata_transmitted(
            _EXPECTED_INITIAL_METADATA, call.initial_metadata()
        )
        # The server yields one response per request it receives.
        responses = list(call)
        self.assertEqual([_RESPONSE] * test_constants.STREAM_LENGTH, responses)
        # Trailing metadata is checked only after the stream is drained.
        self._assert_metadata_transmitted(
            _EXPECTED_TRAILING_METADATA, call.trailing_metadata()
        )
269
270
if __name__ == "__main__":
    # Install a default logging handler before running the suite verbosely.
    logging.basicConfig()
    unittest.main(verbosity=2)
274