# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the low-level cygrpc bindings: servers, channels, calls and
completion queues, exercised over both insecure and SSL connections."""

import time
import threading
import unittest
import platform

from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources

_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0


def _metadata_plugin(context, callback):
    """Metadata plugin that always supplies one fixed key/value pair.

    Args:
      context: The plugin context supplied by cygrpc (unused here).
      callback: Invoked with the metadata, `cygrpc.StatusCode.ok` and empty
        details.
    """
    callback(((
        _CALL_CREDENTIALS_METADATA_KEY,
        _CALL_CREDENTIALS_METADATA_VALUE,
    ),), cygrpc.StatusCode.ok, b'')


class TypeSmokeTest(unittest.TestCase):
    """Creates and destroys individual cygrpc objects in isolation."""

    def testCompletionQueueUpDown(self):
        completion_queue = cygrpc.CompletionQueue()
        del completion_queue

    def testServerUpDown(self):
        server = cygrpc.Server(set([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ]))
        del server

    def testChannelUpDown(self):
        channel = cygrpc.Channel(b'[::]:0', None, None)
        channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

    def test_metadata_plugin_call_credentials_up_down(self):
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
                                             b'test plugin name!')

    def testServerStartNoExplicitShutdown(self):
        # A started server is dropped without calling shutdown().
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        completion_queue = cygrpc.CompletionQueue()
        server.register_completion_queue(completion_queue)
        port = server.add_http2_port(b'[::]:0')
        self.assertIsInstance(port, int)
        server.start()
        del server

    def testServerStartShutdown(self):
        # An explicit shutdown is reported on the completion queue with the
        # tag passed to shutdown().
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        server.add_http2_port(b'[::]:0')
        server.register_completion_queue(completion_queue)
        server.start()
        shutdown_tag = object()
        server.shutdown(completion_queue, shutdown_tag)
        event = completion_queue.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         event.completion_type)
        self.assertIs(shutdown_tag, event.tag)
        del server
        del completion_queue


class ServerClientMixin(object):
    """End-to-end client/server tests shared by the concrete test cases.

    Concrete subclasses (which also derive from `unittest.TestCase`) call
    `setUpMixin` and `tearDownMixin` from their own `setUp`/`tearDown`.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an ephemeral port and connect a channel to it.

        Args:
          server_credentials: Server credentials, or a falsy value for an
            insecure port.
          client_credentials: Channel credentials, or a falsy value for an
            insecure channel.
          host_override: When truthy, it is used as the SSL target name
            override (only if client credentials are given) and as the host
            the server is expected to observe. When falsy, calls pass an
            explicit host argument instead.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(),
                client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(), set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and release the objects from setUpMixin."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception whose message includes
        `description`. The operations are performed asynchronously, and a
        future resolving to the completion event is returned.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Format the exception object itself: Python 3 exceptions
                # have no `.message` attribute, and reading it here would
                # raise AttributeError and hide the original failure.
                raise Exception("Error in '{}': {}".format(description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        # One complete exchange. The client sends initial metadata (ASCII
        # and binary), a request and a half-close in a single batch. The
        # server replies with initial metadata, a response and a status.
        # Each operation's result is then checked on both sides.
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ])
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                        DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client
            ]), found_client_op_types)

        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server
            ]), found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        # The name presumably refers to gRPC issue #6522 (confirm). This test
        # drives a segregated call through a prologue, ten rounds of empty
        # messages in each direction, and an epilogue, forcing both sides to
        # complete at every step.
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        self.server.request_call(self.server_completion_queue,
                                 self.server_completion_queue,
                                 server_request_tag)
        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None,
            ([(
                [
                    cygrpc.SendInitialMetadataOperation(empty_metadata,
                                                        _EMPTY_FLAGS),
                    cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                ],
                object(),
            ),
              (
                  [
                      cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                  ],
                  object(),
              )]))

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(lambda: {
            client_call.next_event(),
            client_call.next_event(),
        })

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()


class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with no credentials on either side."""

    def setUp(self):
        self.setUpMixin(None, None, None)

    def tearDown(self):
        self.tearDownMixin()


class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests over SSL using the test certificates."""

    def setUp(self):
        server_credentials = cygrpc.server_credentials_ssl(
            None, [
                cygrpc.SslPemKeyCertPair(resources.private_key(),
                                         resources.certificate_chain())
            ], False)
        client_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        self.setUpMixin(server_credentials, client_credentials,
                        _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()


if __name__ == '__main__':
    unittest.main(verbosity=2)