# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that exceptions raised by user code surface as UNKNOWN RPC errors."""

import logging
import unittest

import grpc

from tests.unit import test_common
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control

# Toy (de)serializers: requests are doubled on the wire and responses are
# tripled, so the matching deserializers keep the last half / first third.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3]

_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_UNARY_NESTED_EXCEPTION = "/test/UnaryUnaryNestedException"
_UNARY_STREAM = "/test/UnaryStream"
_STREAM_UNARY = "/test/StreamUnary"
_STREAM_STREAM = "/test/StreamStream"
_DEFECTIVE_GENERIC_RPC_HANDLER = "/test/DefectiveGenericRpcHandler"

# Trailing metadata attached by every non-defective handler below.
_TRAILING_METADATA = (("testkey", "testvalue"),)


def _maybe_set_trailing_metadata(servicer_context):
    """Attach _TRAILING_METADATA to servicer_context unless it is None."""
    if servicer_context is not None:
        servicer_context.set_trailing_metadata(_TRAILING_METADATA)


class _Handler(object):
    """Servicer behaviors that _GenericHandler routes each test method to.

    The non-defective handlers call ``self._control.control()`` at fixed
    checkpoints. The tests construct the control as
    test_control.PauseFailControl, which presumably can pause or fail a
    handler at those points.
    """

    def __init__(self, control):
        self._control = control

    def handle_unary_unary(self, request, servicer_context):
        """Echo the request after a single control checkpoint."""
        self._control.control()
        _maybe_set_trailing_metadata(servicer_context)
        return request

    def handle_unary_unary_with_nested_exception(
        self, request, servicer_context
    ):
        """Always raise test_control.NestedDefect."""
        raise test_control.NestedDefect()

    def handle_unary_stream(self, request, servicer_context):
        """Yield the request STREAM_LENGTH times, checkpointing before each."""
        for _ in range(test_constants.STREAM_LENGTH):
            self._control.control()
            yield request
        self._control.control()
        _maybe_set_trailing_metadata(servicer_context)

    def handle_stream_unary(self, request_iterator, servicer_context):
        """Concatenate every streamed request into one bytes response."""
        if servicer_context is not None:
            # The result is discarded; presumably this only exercises the
            # accessor inside a handler.
            servicer_context.invocation_metadata()
        self._control.control()
        response_elements = []
        for request in request_iterator:
            self._control.control()
            response_elements.append(request)
        self._control.control()
        _maybe_set_trailing_metadata(servicer_context)
        return b"".join(response_elements)

    def handle_stream_stream(self, request_iterator, servicer_context):
        """Echo each streamed request back, checkpointing before each one."""
        self._control.control()
        _maybe_set_trailing_metadata(servicer_context)
        for request in request_iterator:
            self._control.control()
            yield request
        self._control.control()

    def defective_generic_rpc_handler(self):
        """Raise test_control.Defect instead of returning a method handler."""
        raise test_control.Defect()


class _MethodHandler(grpc.RpcMethodHandler):
    """Plain attribute holder implementing the grpc.RpcMethodHandler fields."""

    def __init__(
        self,
        request_streaming,
        response_streaming,
        request_deserializer,
        response_serializer,
        unary_unary,
        unary_stream,
        stream_unary,
        stream_stream,
    ):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        self.request_deserializer = request_deserializer
        self.response_serializer = response_serializer
        self.unary_unary = unary_unary
        self.unary_stream = unary_stream
        self.stream_unary = stream_unary
        self.stream_stream = stream_stream


class _GenericHandler(grpc.GenericRpcHandler):
    """Maps each /test/* method name to the matching _Handler behavior."""

    def __init__(self, handler):
        self._handler = handler

    def service(self, handler_call_details):
        """Return a _MethodHandler for the called method, or None if unknown.

        For _DEFECTIVE_GENERIC_RPC_HANDLER the lookup itself raises
        test_control.Defect.
        """
        method = handler_call_details.method
        if method == _UNARY_UNARY:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=False,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=self._handler.handle_unary_unary,
                unary_stream=None,
                stream_unary=None,
                stream_stream=None,
            )
        elif method == _UNARY_STREAM:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=True,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=self._handler.handle_unary_stream,
                stream_unary=None,
                stream_stream=None,
            )
        elif method == _STREAM_UNARY:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=False,
                request_deserializer=_DESERIALIZE_REQUEST,
                response_serializer=_SERIALIZE_RESPONSE,
                unary_unary=None,
                unary_stream=None,
                stream_unary=self._handler.handle_stream_unary,
                stream_stream=None,
            )
        elif method == _STREAM_STREAM:
            return _MethodHandler(
                request_streaming=True,
                response_streaming=True,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=None,
                unary_stream=None,
                stream_unary=None,
                stream_stream=self._handler.handle_stream_stream,
            )
        elif method == _DEFECTIVE_GENERIC_RPC_HANDLER:
            return self._handler.defective_generic_rpc_handler()
        elif method == _UNARY_UNARY_NESTED_EXCEPTION:
            return _MethodHandler(
                request_streaming=False,
                response_streaming=False,
                request_deserializer=None,
                response_serializer=None,
                unary_unary=self._handler.handle_unary_unary_with_nested_exception,
                unary_stream=None,
                stream_unary=None,
                stream_stream=None,
            )
        else:
            return None


class FailAfterFewIterationsCounter(object):
    """Iterator yielding ``bytestring`` ``high`` times, then raising Defect."""

    def __init__(self, high, bytestring):
        self._current = 0
        self._high = high
        self._bytestring = bytestring

    def __iter__(self):
        return self

    def __next__(self):
        if self._current >= self._high:
            raise test_control.Defect()
        else:
            self._current += 1
            return self._bytestring

    # Legacy Python 2 iterator protocol alias.
    next = __next__


def _unary_unary_multi_callable(channel):
    """Return a unary-unary callable for _UNARY_UNARY (no serializers)."""
    return channel.unary_unary(
        _UNARY_UNARY,
        _registered_method=True,
    )


def _unary_stream_multi_callable(channel):
    """Return a unary-stream callable for _UNARY_STREAM using toy codecs."""
    return channel.unary_stream(
        _UNARY_STREAM,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def _stream_unary_multi_callable(channel):
    """Return a stream-unary callable for _STREAM_UNARY using toy codecs."""
    return channel.stream_unary(
        _STREAM_UNARY,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE,
        _registered_method=True,
    )


def _stream_stream_multi_callable(channel):
    """Return a stream-stream callable for _STREAM_STREAM (no serializers)."""
    return channel.stream_stream(
        _STREAM_STREAM,
        _registered_method=True,
    )


def _defective_handler_multi_callable(channel):
    """Return a unary-unary callable whose server-side lookup raises."""
    return channel.unary_unary(
        _DEFECTIVE_GENERIC_RPC_HANDLER,
        _registered_method=True,
    )


def _defective_nested_exception_handler_multi_callable(channel):
    """Return a unary-unary callable whose server handler raises NestedDefect."""
    return channel.unary_unary(
        _UNARY_UNARY_NESTED_EXCEPTION,
        _registered_method=True,
    )


class InvocationDefectsTest(unittest.TestCase):
    """Tests the handling of exception-raising user code.

    Covers defective client-side request iterables/iterators as well as
    server-side handler lookups and handlers that raise; each case must
    surface to the caller as an RpcError with StatusCode.UNKNOWN.
    """

    def setUp(self):
        self._control = test_control.PauseFailControl()
        self._handler = _Handler(self._control)

        self._server = test_common.test_server()
        port = self._server.add_insecure_port("[::]:0")
        self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:%d" % port)

    def tearDown(self):
        self._server.stop(0)
        self._channel.close()

    def testIterableStreamRequestBlockingUnaryResponse(self):
        # A bare object() is not iterable, so request iteration fails.
        requests = object()
        multi_callable = _stream_unary_multi_callable(self._channel)

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(
                requests,
                metadata=(
                    ("test", "IterableStreamRequestBlockingUnaryResponse"),
                ),
            )

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def testIterableStreamRequestFutureUnaryResponse(self):
        # A bare object() is not iterable, so request iteration fails.
        requests = object()
        multi_callable = _stream_unary_multi_callable(self._channel)
        response_future = multi_callable.future(
            requests,
            metadata=(("test", "IterableStreamRequestFutureUnaryResponse"),),
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            response_future.result()

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def testIterableStreamRequestStreamResponse(self):
        # A bare object() is not iterable, so request iteration fails.
        requests = object()
        multi_callable = _stream_stream_multi_callable(self._channel)
        response_iterator = multi_callable(
            requests,
            metadata=(("test", "IterableStreamRequestStreamResponse"),),
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def testIteratorStreamRequestStreamResponse(self):
        # The request iterator raises Defect after STREAM_LENGTH // 2 items;
        # reading one response past that point must hit the error.
        requests_iterator = FailAfterFewIterationsCounter(
            test_constants.STREAM_LENGTH // 2, b"\x07\x08"
        )
        multi_callable = _stream_stream_multi_callable(self._channel)
        response_iterator = multi_callable(
            requests_iterator,
            metadata=(("test", "IteratorStreamRequestStreamResponse"),),
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            for _ in range(test_constants.STREAM_LENGTH // 2 + 1):
                next(response_iterator)

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def testDefectiveGenericRpcHandlerUnaryResponse(self):
        request = b"\x07\x08"
        multi_callable = _defective_handler_multi_callable(self._channel)

        with self.assertRaises(grpc.RpcError) as exception_context:
            multi_callable(
                request, metadata=(("test", "DefectiveGenericRpcHandlerUnary"),)
            )

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )

    def testNestedExceptionGenericRpcHandlerUnaryResponse(self):
        request = b"\x07\x08"
        multi_callable = _defective_nested_exception_handler_multi_callable(
            self._channel
        )

        with self.assertRaises(grpc.RpcError) as exception_context:
            # Label the RPC with this test's own name (was a copy-paste of
            # the DefectiveGenericRpcHandler label).
            multi_callable(
                request,
                metadata=(("test", "NestedExceptionGenericRpcHandlerUnary"),),
            )

        self.assertIs(
            grpc.StatusCode.UNKNOWN, exception_context.exception.code()
        )


if __name__ == "__main__":
    logging.basicConfig()
    unittest.main(verbosity=2)