# Copyright 2018 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for grpc.server handler registration.

Covers rejection of objects that are not GenericRpcHandlers, port-binding
failures, and dispatch between generic handlers and registered method
handlers.
"""

from concurrent import futures
import logging
import unittest

import grpc

from tests.unit import resources
from tests.unit import test_common
from tests.unit.framework.common import test_constants

_REQUEST = b""
_RESPONSE = b"response"
# Distinct payload so a test can tell whether the registered handler (rather
# than a generic one) served the call.
_REGISTERED_RESPONSE = b"registered_response"

_SERVICE_NAME = "test"
_UNARY_UNARY = "UnaryUnary"
_UNARY_UNARY_REGISTERED = "UnaryUnaryRegistered"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"


class _ActualGenericRpcHandler(grpc.GenericRpcHandler):
    """A genuine GenericRpcHandler that declines every RPC.

    Used next to a non-handler object to check that the server rejects the
    invalid entry rather than the valid one.
    """

    def service(self, handler_call_details):
        return None


def handle_unary_unary(request, servicer_context):
    """Reply to a unary request with the generic response."""
    return _RESPONSE


def handle_registered_unary_unary(request, servicer_context):
    """Reply to a unary request with the registered-handler response."""
    return _REGISTERED_RESPONSE


def handle_unary_stream(request, servicer_context):
    """Yield the generic response test_constants.STREAM_LENGTH times."""
    for _ in range(test_constants.STREAM_LENGTH):
        yield _RESPONSE


def handle_stream_unary(request_iterator, servicer_context):
    """Drain the request stream, then reply once with the generic response."""
    for _ in request_iterator:
        pass
    return _RESPONSE


def handle_stream_stream(request_iterator, servicer_context):
    """Yield one generic response for every request received."""
    for _ in request_iterator:
        yield _RESPONSE


class _MethodHandler(grpc.RpcMethodHandler):
    """RpcMethodHandler that wires up exactly one of the handle_* functions.

    Args:
      request_streaming: Whether the method takes a stream of requests.
      response_streaming: Whether the method returns a stream of responses.
      registered: Only consulted for unary-unary methods; selects
        handle_registered_unary_unary instead of handle_unary_unary.
    """

    def __init__(self, request_streaming, response_streaming, registered=False):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # No (de)serializers: the handlers exchange the raw byte payloads.
        self.request_deserializer = None
        self.response_serializer = None
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None
        if self.request_streaming and self.response_streaming:
            self.stream_stream = handle_stream_stream
        elif self.request_streaming:
            self.stream_unary = handle_stream_unary
        elif self.response_streaming:
            self.unary_stream = handle_unary_stream
        else:
            if registered:
                self.unary_unary = handle_registered_unary_unary
            else:
                self.unary_unary = handle_unary_unary


class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that serves all four arities plus the registered name.

    The four arity methods are matched by their bare names. The fully
    qualified registered method name is also served, but with the
    *non-registered* unary-unary handler, so tests can observe which handler
    wins when both could serve the call.
    """

    def service(self, handler_call_details):
        if handler_call_details.method == _UNARY_UNARY:
            return _MethodHandler(False, False)
        elif handler_call_details.method == _UNARY_STREAM:
            return _MethodHandler(False, True)
        elif handler_call_details.method == _STREAM_UNARY:
            return _MethodHandler(True, False)
        elif handler_call_details.method == _STREAM_STREAM:
            return _MethodHandler(True, True)
        elif handler_call_details.method == grpc._common.fully_qualified_method(
            _SERVICE_NAME, _UNARY_UNARY_REGISTERED
        ):
            return _MethodHandler(False, False)
        else:
            return None


class _GenericHandlerWithRegisteredName(grpc.GenericRpcHandler):
    """Generic handler that serves only the registered method's full name.

    It answers with the non-registered unary-unary handler, competing with
    the registered handler for the same method.
    """

    def service(self, handler_call_details):
        if handler_call_details.method == grpc._common.fully_qualified_method(
            _SERVICE_NAME, _UNARY_UNARY_REGISTERED
        ):
            return _MethodHandler(False, False)
        else:
            return None


# Method-name -> handler map passed to server.add_registered_method_handlers.
_REGISTERED_METHOD_HANDLERS = {
    _UNARY_UNARY_REGISTERED: _MethodHandler(False, False, True),
}


class ServerTest(unittest.TestCase):
    """Tests of server construction and port binding."""

    def test_not_a_generic_rpc_handler_at_construction(self):
        with self.assertRaises(AttributeError) as exception_context:
            grpc.server(
                futures.ThreadPoolExecutor(max_workers=5),
                handlers=[
                    _ActualGenericRpcHandler(),
                    object(),
                ],
            )
        self.assertIn(
            "grpc.GenericRpcHandler", str(exception_context.exception)
        )

    def test_not_a_generic_rpc_handler_after_construction(self):
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
        with self.assertRaises(AttributeError) as exception_context:
            server.add_generic_rpc_handlers(
                [
                    _ActualGenericRpcHandler(),
                    object(),
                ]
            )
        self.assertIn(
            "grpc.GenericRpcHandler", str(exception_context.exception)
        )

    def test_failed_port_binding_exception(self):
        # so_reuseport is disabled so that binding the same address a second
        # time fails instead of silently sharing the port.
        server = grpc.server(None, options=(("grpc.so_reuseport", 0),))
        port = server.add_insecure_port("localhost:0")
        bind_address = "localhost:%d" % port

        with self.assertRaises(RuntimeError):
            server.add_insecure_port(bind_address)

        server_credentials = grpc.ssl_server_credentials(
            [(resources.private_key(), resources.certificate_chain())]
        )
        with self.assertRaises(RuntimeError):
            server.add_secure_port(bind_address, server_credentials)


class ServerHandlerTest(unittest.TestCase):
    """End-to-end tests of how the server dispatches calls to handlers."""

    def setUp(self):
        # Each test creates its own server and channel; start from None so
        # tearDown knows whether there is anything to clean up.
        self._server = None
        self._channel = None

    def tearDown(self):
        # Guard each resource so a test that failed before creating it does
        # not raise AttributeError here and mask the original failure.
        if self._server is not None:
            self._server.stop(0)
        if self._channel is not None:
            self._channel.close()

    def test_generic_unary_unary_handler(self):
        self._server = test_common.test_server()
        # Register before start(), like the other generic-handler tests.
        # Registration after start is covered separately by
        # test_add_generic_handler_after_server_start.
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, response)

    def test_generic_unary_stream_handler(self):
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        response_iterator = self._channel.unary_stream(
            _UNARY_STREAM,
            _registered_method=True,
        )(_REQUEST)
        self.assertSequenceEqual(
            [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
        )

    def test_generic_stream_unary_handler(self):
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        response = self._channel.stream_unary(
            _STREAM_UNARY,
            _registered_method=True,
        )(iter([_REQUEST] * test_constants.STREAM_LENGTH))
        self.assertEqual(_RESPONSE, response)

    def test_generic_stream_stream_handler(self):
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        response_iterator = self._channel.stream_stream(
            _STREAM_STREAM,
            _registered_method=True,
        )(iter([_REQUEST] * test_constants.STREAM_LENGTH))
        self.assertSequenceEqual(
            [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
        )

    def test_add_generic_handler_after_server_start(self):
        # Generic handlers added after start() are still served.
        self._server = test_common.test_server()
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, response)

    def test_add_registered_handler_after_server_start(self):
        # Registered method handlers added after start() are not served: the
        # call is expected to fail with "Method not found".
        self._server = test_common.test_server()
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, _REGISTERED_METHOD_HANDLERS
        )
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        with self.assertRaises(grpc.RpcError) as exception_context:
            self._channel.unary_unary(
                grpc._common.fully_qualified_method(
                    _SERVICE_NAME, _UNARY_UNARY_REGISTERED
                ),
                _registered_method=True,
            )(_REQUEST)

        self.assertIn("Method not found", str(exception_context.exception))

    def test_server_with_both_registered_and_generic_handlers(self):
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers((_GenericHandler(),))
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, _REGISTERED_METHOD_HANDLERS
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        generic_response = self._channel.unary_unary(
            _UNARY_UNARY,
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_RESPONSE, generic_response)

        registered_response = self._channel.unary_unary(
            grpc._common.fully_qualified_method(
                _SERVICE_NAME, _UNARY_UNARY_REGISTERED
            ),
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_REGISTERED_RESPONSE, registered_response)

    def test_server_registered_handler_take_precedence(self):
        # When the same method has both a generic and a registered handler,
        # the registered handler takes precedence.
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers(
            (_GenericHandlerWithRegisteredName(),)
        )
        self._server.add_registered_method_handlers(
            _SERVICE_NAME, _REGISTERED_METHOD_HANDLERS
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()
        self._channel = grpc.insecure_channel("localhost:%d" % port)

        registered_response = self._channel.unary_unary(
            grpc._common.fully_qualified_method(
                _SERVICE_NAME, _UNARY_UNARY_REGISTERED
            ),
            _registered_method=True,
        )(_REQUEST)
        self.assertEqual(_REGISTERED_RESPONSE, registered_response)


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)