# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests and end-to-end RPC tests for the low-level cygrpc bindings."""

import platform
import threading
import time
import unittest

from grpc._cython import cygrpc

from tests.unit import resources
from tests.unit import test_common
from tests.unit._cython import test_utilities

_SSL_HOST_OVERRIDE = b"foo.test.google.fr"
_CALL_CREDENTIALS_METADATA_KEY = "call-creds-key"
_CALL_CREDENTIALS_METADATA_VALUE = "call-creds-value"
_EMPTY_FLAGS = 0


def _metadata_plugin(context, callback):
    """Metadata plugin that always supplies one fixed call-credentials pair.

    Args:
      context: Plugin context supplied by cygrpc (unused here).
      callback: Invoked with the metadata tuple, an OK status code and empty
        error details.
    """
    callback(
        (
            (
                _CALL_CREDENTIALS_METADATA_KEY,
                _CALL_CREDENTIALS_METADATA_VALUE,
            ),
        ),
        cygrpc.StatusCode.ok,
        b"",
    )


class TypeSmokeTest(unittest.TestCase):
    """Create and dispose of individual cygrpc objects without running RPCs."""

    def testCompletionQueueUpDown(self):
        completion_queue = cygrpc.CompletionQueue()
        del completion_queue

    def testServerUpDown(self):
        server = cygrpc.Server(
            set(
                [
                    (
                        b"grpc.so_reuseport",
                        0,
                    )
                ]
            ),
            False,
        )
        del server

    def testChannelUpDown(self):
        channel = cygrpc.Channel(b"[::]:0", None, None)
        channel.close(cygrpc.StatusCode.cancelled, "Test method anyway!")

    def test_metadata_plugin_call_credentials_up_down(self):
        cygrpc.MetadataPluginCallCredentials(
            _metadata_plugin, b"test plugin name!"
        )

    def testServerStartNoExplicitShutdown(self):
        server = cygrpc.Server(
            [
                (
                    b"grpc.so_reuseport",
                    0,
                )
            ],
            False,
        )
        completion_queue = cygrpc.CompletionQueue()
        server.register_completion_queue(completion_queue)
        port = server.add_http2_port(b"[::]:0")
        self.assertIsInstance(port, int)
        server.start()
        # Deliberately dropped without calling server.shutdown().
        del server

    def testServerStartShutdown(self):
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(
            [
                (
                    b"grpc.so_reuseport",
                    0,
                ),
            ],
            False,
        )
        server.add_http2_port(b"[::]:0")
        server.register_completion_queue(completion_queue)
        server.start()
        shutdown_tag = object()
        server.shutdown(completion_queue, shutdown_tag)
        # Shutdown completion is reported on the registered queue under the
        # tag passed to shutdown().
        event = completion_queue.poll()
        self.assertEqual(
            cygrpc.CompletionType.operation_complete, event.completion_type
        )
        self.assertIs(shutdown_tag, event.tag)
        del server
        del completion_queue


class ServerClientMixin(object):
    """RPC tests shared by the insecure and secure server/client test cases.

    Concrete ``unittest.TestCase`` subclasses call ``setUpMixin`` from
    ``setUp`` and ``tearDownMixin`` from ``tearDown``.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an ephemeral port and connect a client channel.

        Args:
          server_credentials: Server credentials, or a falsy value for an
            insecure port.
          client_credentials: Channel credentials, or a falsy value for an
            insecure channel.
          host_override: Value for the SSL target-name-override channel
            argument. It is also the host the server expects to see when
            truthy.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server(
            [
                (
                    b"grpc.so_reuseport",
                    0,
                )
            ],
            False,
        )
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(
                b"[::]:0", server_credentials
            )
        else:
            self.port = self.server.add_http2_port(b"[::]:0")
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = (
                (
                    cygrpc.ChannelArgKey.ssl_target_name_override,
                    host_override,
                ),
            )
            self.client_channel = cygrpc.Channel(
                "localhost:{}".format(self.port).encode(),
                client_channel_arguments,
                client_credentials,
            )
        else:
            self.client_channel = cygrpc.Channel(
                "localhost:{}".format(self.port).encode(), set(), None
            )
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b"hostess"
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and drop references created by setUpMixin."""
        self.client_channel.close(cygrpc.StatusCode.ok, "test being torn down!")
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(
        self, operations, call, queue, deadline, description
    ):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description`
        in the message. Performs the operations asynchronously, returning a
        future.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(
                    cygrpc.CompletionType.operation_complete,
                    event.completion_type,
                )
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Python 3 exceptions have no ``message`` attribute, so
                # reading it here would replace the real failure with an
                # AttributeError. Format the exception itself and chain it
                # so the original traceback is kept.
                raise Exception(
                    "Error in '{}': {}".format(description, error)
                ) from error
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = "key"
        CLIENT_METADATA_ASCII_VALUE = "val"
        CLIENT_METADATA_BIN_KEY = "key-bin"
        CLIENT_METADATA_BIN_VALUE = b"\0" * 1000
        SERVER_INITIAL_METADATA_KEY = "init_me_me_me"
        SERVER_INITIAL_METADATA_VALUE = "whodawha?"
        SERVER_TRAILING_METADATA_KEY = "california_is_in_a_drought"
        SERVER_TRAILING_METADATA_VALUE = "zomg it is"
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = "our work is never over"
        REQUEST = b"in death a member of project mayhem has a name"
        RESPONSE = b"his name is robert paulson"
        METHOD = b"twinkies"

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue,
            self.server_completion_queue,
            server_request_tag,
        )

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        # The whole unary exchange is submitted as one client batch of six
        # operations.
        client_call = self.client_channel.integrated_call(
            0,
            METHOD,
            self.host_argument,
            DEADLINE,
            client_initial_metadata,
            None,
            [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS
                        ),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ],
        )
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event
        )

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(
            cygrpc.CompletionType.operation_complete,
            request_event.completion_type,
        )
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(
                client_initial_metadata, request_event.invocation_metadata
            )
        )
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(
            abs(DEADLINE - request_event.call_details.deadline),
            DEADLINE_TOLERANCE,
        )

        server_call_tag = object()
        server_call = request_event.call
        server_start_batch_result = server_call.start_server_batch(
            [
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ],
            server_call_tag,
        )
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_message_event = self.server_completion_queue.poll(
            deadline=DEADLINE
        )

        server_call_tag = object()
        server_initial_metadata = (
            (
                SERVER_INITIAL_METADATA_KEY,
                SERVER_INITIAL_METADATA_VALUE,
            ),
        )
        server_trailing_metadata = (
            (
                SERVER_TRAILING_METADATA_KEY,
                SERVER_TRAILING_METADATA_VALUE,
            ),
        )
        server_start_batch_result = server_call.start_server_batch(
            [
                cygrpc.SendInitialMetadataOperation(
                    server_initial_metadata, _EMPTY_FLAGS
                ),
                cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
                cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    server_trailing_metadata,
                    SERVER_STATUS_CODE,
                    SERVER_STATUS_DETAILS,
                    _EMPTY_FLAGS,
                ),
            ],
            server_call_tag,
        )
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if (
                client_result.type()
                == cygrpc.OperationType.receive_initial_metadata
            ):
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata(),
                    )
                )
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif (
                client_result.type()
                == cygrpc.OperationType.receive_status_on_client
            ):
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata(),
                    )
                )
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set(
                [
                    cygrpc.OperationType.send_initial_metadata,
                    cygrpc.OperationType.send_message,
                    cygrpc.OperationType.send_close_from_client,
                    cygrpc.OperationType.receive_initial_metadata,
                    cygrpc.OperationType.receive_message,
                    cygrpc.OperationType.receive_status_on_client,
                ]
            ),
            found_client_op_types,
        )

        self.assertEqual(1, len(server_message_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_message_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif (
                server_result.type()
                == cygrpc.OperationType.receive_close_on_server
            ):
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set(
                [
                    cygrpc.OperationType.receive_message,
                ]
            ),
            found_server_op_types,
        )

        self.assertEqual(4, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif (
                server_result.type()
                == cygrpc.OperationType.receive_close_on_server
            ):
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set(
                [
                    cygrpc.OperationType.send_initial_metadata,
                    cygrpc.OperationType.send_message,
                    cygrpc.OperationType.receive_close_on_server,
                    cygrpc.OperationType.send_status_from_server,
                ]
            ),
            found_server_op_types,
        )

        del client_call
        del server_call

    def test_6522(self):
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        METHOD = b"twinkies"

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        self.server.request_call(
            self.server_completion_queue,
            self.server_completion_queue,
            server_request_tag,
        )
        client_call = self.client_channel.segregated_call(
            0,
            METHOD,
            self.host_argument,
            DEADLINE,
            None,
            None,
            (
                [
                    (
                        [
                            cygrpc.SendInitialMetadataOperation(
                                empty_metadata, _EMPTY_FLAGS
                            ),
                            cygrpc.ReceiveInitialMetadataOperation(
                                _EMPTY_FLAGS
                            ),
                        ],
                        object(),
                    ),
                    (
                        [
                            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                        ],
                        object(),
                    ),
                ]
            ),
        )

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event
        )

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(
                operations,
                server_call,
                self.server_completion_queue,
                DEADLINE,
                description,
            )

        server_event_future = perform_server_operations(
            [
                cygrpc.SendInitialMetadataOperation(
                    empty_metadata, _EMPTY_FLAGS
                ),
            ],
            "Server prologue",
        )

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging: several rounds of empty messages in both directions.
        for _ in range(10):
            client_call.operate(
                [
                    cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS),
                    cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                ],
                "Client message",
            )
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event
            )
            server_event_future = perform_server_operations(
                [
                    cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS),
                    cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                ],
                "Server receive",
            )

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate(
            [
                cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            ],
            "Client epilogue",
        )
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(
            lambda: {
                client_call.next_event(),
                client_call.next_event(),
            }
        )

        server_event_future = perform_server_operations(
            [
                cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    empty_metadata, cygrpc.StatusCode.ok, b"", _EMPTY_FLAGS
                ),
            ],
            "Server epilogue",
        )

        client_events_future.result()  # force completion
        server_event_future.result()


class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests without transport security."""

    def setUp(self):
        self.setUpMixin(None, None, None)

    def tearDown(self):
        self.tearDownMixin()


class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests over SSL with test certificates."""

    def setUp(self):
        server_credentials = cygrpc.server_credentials_ssl(
            None,
            [
                cygrpc.SslPemKeyCertPair(
                    resources.private_key(), resources.certificate_chain()
                )
            ],
            False,
        )
        client_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None
        )
        self.setUpMixin(
            server_credentials, client_credentials, _SSL_HOST_OVERRIDE
        )

    def tearDown(self):
        self.tearDownMixin()


if __name__ == "__main__":
    unittest.main(verbosity=2)