# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke and end-to-end tests for the low-level ``cygrpc`` bindings."""

import time
import threading
import unittest
import platform

from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources

_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0


def _metadata_plugin(context, callback):
    """Metadata plugin that always supplies one fixed call-credentials pair.

    Args:
      context: Plugin context passed in by cygrpc; ignored here.
      callback: Invoked once with the fixed metadata, an OK status code and
        empty (bytes) details.
    """
    callback(((
        _CALL_CREDENTIALS_METADATA_KEY,
        _CALL_CREDENTIALS_METADATA_VALUE,
    ),), cygrpc.StatusCode.ok, b'')


class TypeSmokeTest(unittest.TestCase):
    """Creates and tears down individual cygrpc objects without any RPCs."""

    def testCompletionQueueUpDown(self):
        completion_queue = cygrpc.CompletionQueue()
        del completion_queue

    def testServerUpDown(self):
        server = cygrpc.Server(set([(
            b'grpc.so_reuseport',
            0,
        )]), False)
        del server

    def testChannelUpDown(self):
        channel = cygrpc.Channel(b'[::]:0', None, None)
        channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

    def test_metadata_plugin_call_credentials_up_down(self):
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
                                             b'test plugin name!')

    def testServerStartNoExplicitShutdown(self):
        server = cygrpc.Server([(
            b'grpc.so_reuseport',
            0,
        )], False)
        completion_queue = cygrpc.CompletionQueue()
        server.register_completion_queue(completion_queue)
        port = server.add_http2_port(b'[::]:0')
        self.assertIsInstance(port, int)
        server.start()
        # Deliberately drop the started server without calling shutdown().
        del server

    def testServerStartShutdown(self):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ], False)
        server.add_http2_port(b'[::]:0')
        server.register_completion_queue(completion_queue)
        server.start()
        shutdown_tag = object()
        server.shutdown(completion_queue, shutdown_tag)
        # The shutdown completion is delivered on the registered queue.
        event = completion_queue.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         event.completion_type)
        self.assertIs(shutdown_tag, event.tag)
        del server
        del completion_queue


class ServerClientMixin(object):
    """Shared client/server fixture and RPC tests.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and call
    ``setUpMixin`` / ``tearDownMixin`` from ``setUp`` / ``tearDown``.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an ephemeral port and open a client channel to it.

        Args:
          server_credentials: Credentials for the server port, or a falsy
            value for an insecure port.
          client_credentials: Channel credentials, or a falsy value for an
            insecure channel.
          host_override: SSL target-name override (bytes) used when
            ``client_credentials`` is set. When truthy, calls pass no explicit
            host and the override is the host the server should observe.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(
            b'grpc.so_reuseport',
            0,
        )], False)
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(),
                client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(), set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and release the fixture objects."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description`
        in the message. Performs the operations asynchronously, returning a
        future whose result is the completion event polled from `queue`.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Format the exception itself: Python 3 exceptions have no
                # `.message` attribute, and reading it raised AttributeError,
                # hiding the real failure. str(error) works on 2 and 3.
                raise Exception("Error in '{}': {}".format(description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        """Single unary exchange with metadata, message and status checks."""
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        # The whole client side of the RPC is submitted as one batch.
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ])
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                        DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client
            ]), found_client_op_types)

        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server
            ]), found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        """Streams ten empty message exchanges over one segregated call.

        NOTE(review): the name presumably refers to gRPC issue #6522 -- confirm.
        """
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)
        self.assertEqual(cygrpc.CallError.ok, request_call_result)
        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None,
            ([(
                [
                    cygrpc.SendInitialMetadataOperation(empty_metadata,
                                                        _EMPTY_FLAGS),
                    cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                ],
                object(),
            ),
              (
                  [
                      cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                  ],
                  object(),
              )]))

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(lambda: {
            client_call.next_event(),
            client_call.next_event(),
        })

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()


class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the mixin tests over an insecure port and insecure channel."""

    def setUp(self):
        self.setUpMixin(None, None, None)

    def tearDown(self):
        self.tearDownMixin()


class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the mixin tests over SSL using the test key/certificate resources."""

    def setUp(self):
        server_credentials = cygrpc.server_credentials_ssl(
            None, [
                cygrpc.SslPemKeyCertPair(resources.private_key(),
                                         resources.certificate_chain())
            ], False)
        client_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        self.setUpMixin(server_credentials, client_credentials,
                        _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()


if __name__ == '__main__':
    unittest.main(verbosity=2)