• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The gRPC Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from concurrent import futures
16import logging
17import unittest
18
19import grpc
20
21from tests.unit import resources
22from tests.unit import test_common
23from tests.unit.framework.common import test_constants
24
25_REQUEST = b""
26_RESPONSE = b"response"
27_REGISTERED_RESPONSE = b"registered_response"
28
29_SERVICE_NAME = "test"
30_UNARY_UNARY = "UnaryUnary"
31_UNARY_UNARY_REGISTERED = "UnaryUnaryRegistered"
32_UNARY_STREAM = "UnaryStream"
33_STREAM_UNARY = "StreamUnary"
34_STREAM_STREAM = "StreamStream"
35
36
class _ActualGenericRpcHandler(grpc.GenericRpcHandler):
    """A genuine GenericRpcHandler that never services any method.

    Paired with a non-handler object in the validation tests, so that the
    rejection can only be caused by the invalid entry.
    """

    def service(self, handler_call_details):
        """Decline every RPC by returning None."""
        return
40
41
def handle_unary_unary(request, servicer_context):
    """Answer any unary request with the generic ``_RESPONSE`` payload."""
    reply = _RESPONSE
    return reply
44
45
def handle_registered_unary_unary(request, servicer_context):
    """Answer any unary request with the ``_REGISTERED_RESPONSE`` payload.

    The distinct payload lets tests tell a registered handler apart from
    the generic one.
    """
    reply = _REGISTERED_RESPONSE
    return reply
48
49
def handle_unary_stream(request, servicer_context):
    """Ignore the request and stream ``_RESPONSE`` STREAM_LENGTH times."""
    yield from [_RESPONSE] * test_constants.STREAM_LENGTH
53
54
def handle_stream_unary(request_iterator, servicer_context):
    """Consume every incoming request, then reply once with ``_RESPONSE``."""
    # Drain the stream first so the single reply is produced only after the
    # request iterator is exhausted.
    for _ in request_iterator:
        continue
    return _RESPONSE
59
60
def handle_stream_stream(request_iterator, servicer_context):
    """Emit one ``_RESPONSE`` for each request received on the stream."""
    yield from (_RESPONSE for _ in request_iterator)
64
65
class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler wired to one of this module's ``handle_*`` functions.

    Exactly one of ``unary_unary``, ``unary_stream``, ``stream_unary`` and
    ``stream_stream`` is set, chosen by the streaming flags. The other three
    stay None. No request deserializer or response serializer is installed.

    Args:
      request_streaming: Whether the client sends a stream of requests.
      response_streaming: Whether the server replies with a stream.
      registered: For the unary-unary shape only, select
        ``handle_registered_unary_unary`` instead of ``handle_unary_unary``.
    """

    def __init__(self, request_streaming, response_streaming, registered=False):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        if not request_streaming:
            if not response_streaming:
                # Only the unary-unary shape has a "registered" variant.
                self.unary_unary = (
                    handle_registered_unary_unary
                    if registered
                    else handle_unary_unary
                )
            else:
                self.unary_stream = handle_unary_stream
        elif response_streaming:
            self.stream_stream = handle_stream_stream
        else:
            self.stream_unary = handle_stream_unary
87
88
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler for the four bare method names and the registered name.

    The fully-qualified registered method name is answered with the generic
    unary-unary behavior (``_RESPONSE``). When a registered handler for the
    same method is also installed, a test can therefore see which one served
    the call.
    """

    def service(self, handler_call_details):
        # Map each handled method name to its
        # (request_streaming, response_streaming) flags.
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
            grpc._common.fully_qualified_method(
                _SERVICE_NAME, _UNARY_UNARY_REGISTERED
            ): (False, False),
        }
        flags = streaming_by_method.get(handler_call_details.method)
        if flags is None:
            return None
        return _MethodHandler(*flags)
105
106
class _GenericHandlerWithRegisteredName(grpc.GenericRpcHandler):
    """Generic handler that serves only the fully-qualified registered method.

    It answers with the generic unary-unary behavior (``_RESPONSE``). This
    lets the precedence test check that a registered handler for the same
    name wins.
    """

    def service(self, handler_call_details):
        registered_name = grpc._common.fully_qualified_method(
            _SERVICE_NAME, _UNARY_UNARY_REGISTERED
        )
        if handler_call_details.method != registered_name:
            return None
        return _MethodHandler(False, False)
115
116
# Registered-method table installed under _SERVICE_NAME. It holds the single
# unary-unary method, whose handler replies with _REGISTERED_RESPONSE.
_REGISTERED_METHOD_HANDLERS = {
    _UNARY_UNARY_REGISTERED: _MethodHandler(
        request_streaming=False, response_streaming=False, registered=True
    ),
}
120
121
class ServerTest(unittest.TestCase):
    """Validation of server handler arguments and port binding."""

    def _assert_rejects_non_handler(self, install_handlers):
        """Assert install_handlers() raises AttributeError naming the handler type."""
        with self.assertRaises(AttributeError) as raised:
            install_handlers()
        self.assertIn("grpc.GenericRpcHandler", str(raised.exception))

    def test_not_a_generic_rpc_handler_at_construction(self):
        self._assert_rejects_non_handler(
            lambda: grpc.server(
                futures.ThreadPoolExecutor(max_workers=5),
                handlers=[_ActualGenericRpcHandler(), object()],
            )
        )

    def test_not_a_generic_rpc_handler_after_construction(self):
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
        self._assert_rejects_non_handler(
            lambda: server.add_generic_rpc_handlers(
                [_ActualGenericRpcHandler(), object()]
            )
        )

    def test_failed_port_binding_exception(self):
        # SO_REUSEPORT is disabled so that a second bind to the same address
        # is expected to fail.
        server = grpc.server(None, options=(("grpc.so_reuseport", 0),))
        bound_port = server.add_insecure_port("localhost:0")
        bind_address = "localhost:{}".format(bound_port)

        with self.assertRaises(RuntimeError):
            server.add_insecure_port(bind_address)

        key_and_chain = (resources.private_key(), resources.certificate_chain())
        server_credentials = grpc.ssl_server_credentials([key_and_chain])
        with self.assertRaises(RuntimeError):
            server.add_secure_port(bind_address, server_credentials)
162
163
class ServerHandlerTest(unittest.TestCase):
    """End-to-end dispatch through generic and registered method handlers.

    Each test starts its own server on an ephemeral port and opens an
    insecure channel to it. tearDown releases whichever of the two the test
    actually created.
    """

    def setUp(self):
        # Tests create these themselves. Pre-seeding them with None lets
        # tearDown skip anything a failing test never created, instead of
        # raising AttributeError and masking the original failure.
        self._server = None
        self._channel = None

    def tearDown(self):
        try:
            if self._server is not None:
                self._server.stop(0)
        finally:
            if self._channel is not None:
                self._channel.close()

    @staticmethod
    def _registered_method_name():
        """Return the fully-qualified name of the registered unary-unary method."""
        return grpc._common.fully_qualified_method(
            _SERVICE_NAME, _UNARY_UNARY_REGISTERED
        )

    def _install_handlers(self, generic_handlers, registered_method_handlers):
        """Add generic handlers, then registered handlers, to self._server."""
        if generic_handlers:
            self._server.add_generic_rpc_handlers(tuple(generic_handlers))
        if registered_method_handlers is not None:
            self._server.add_registered_method_handlers(
                _SERVICE_NAME, registered_method_handlers
            )

    def _serve(
        self,
        generic_handlers=(),
        registered_method_handlers=None,
        *,
        after_start=False
    ):
        """Start a test server with the given handlers and connect to it.

        Args:
          generic_handlers: GenericRpcHandler instances to install.
          registered_method_handlers: Optional method-name -> handler mapping,
            installed under _SERVICE_NAME.
          after_start: If True, install the handlers only after the server has
            been started. Otherwise install them before the port is bound.
        """
        self._server = test_common.test_server()
        if not after_start:
            self._install_handlers(generic_handlers, registered_method_handlers)
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        if after_start:
            self._install_handlers(generic_handlers, registered_method_handlers)
        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def test_generic_unary_unary_handler(self):
        # Handlers are installed before start, like the other generic_* tests.
        # The after-start path is covered by
        # test_add_generic_handler_after_server_start.
        self._serve(generic_handlers=(_GenericHandler(),))

        response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, response)

    def test_generic_unary_stream_handler(self):
        self._serve(generic_handlers=(_GenericHandler(),))

        response_iterator = self._channel.unary_stream(
            _UNARY_STREAM,
            _registered_method=True,
        )(_REQUEST)
        self.assertSequenceEqual(
            [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
        )

    def test_generic_stream_unary_handler(self):
        self._serve(generic_handlers=(_GenericHandler(),))

        response = self._channel.stream_unary(
            _STREAM_UNARY,
            _registered_method=True,
        )(iter([_REQUEST] * test_constants.STREAM_LENGTH))
        self.assertEqual(_RESPONSE, response)

    def test_generic_stream_stream_handler(self):
        self._serve(generic_handlers=(_GenericHandler(),))

        response_iterator = self._channel.stream_stream(
            _STREAM_STREAM,
            _registered_method=True,
        )(iter([_REQUEST] * test_constants.STREAM_LENGTH))
        self.assertSequenceEqual(
            [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
        )

    def test_add_generic_handler_after_server_start(self):
        self._serve(generic_handlers=(_GenericHandler(),), after_start=True)

        response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, response)

    def test_add_registered_handler_after_server_start(self):
        # Registered handlers added after start are expected not to be
        # routed; the call should fail with "Method not found".
        self._serve(
            registered_method_handlers=_REGISTERED_METHOD_HANDLERS,
            after_start=True,
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            self._channel.unary_unary(
                self._registered_method_name(),
                _registered_method=True,
            )(_REQUEST)

        self.assertIn("Method not found", str(exception_context.exception))

    def test_server_with_both_registered_and_generic_handlers(self):
        self._serve(
            generic_handlers=(_GenericHandler(),),
            registered_method_handlers=_REGISTERED_METHOD_HANDLERS,
        )

        generic_response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, generic_response)

        registered_response = self._channel.unary_unary(
            self._registered_method_name(),
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_REGISTERED_RESPONSE, registered_response)

    def test_server_registered_handler_take_precedence(self):
        # When the same method has both a generic and a registered handler,
        # the registered handler takes precedence.
        self._serve(
            generic_handlers=(_GenericHandlerWithRegisteredName(),),
            registered_method_handlers=_REGISTERED_METHOD_HANDLERS,
        )

        registered_response = self._channel.unary_unary(
            self._registered_method_name(),
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_REGISTERED_RESPONSE, registered_response)
302
303
if __name__ == "__main__":
    # Configure root logging before running, so that warnings logged during
    # the run are written to stderr. Then hand off to the unittest runner
    # with per-test output.
    logging.basicConfig()
    unittest.main(verbosity=2)
307