# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests exercising the low-level Cython bindings (`cygrpc`) directly."""

import time
import threading
import unittest
import platform

from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources

_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0


def _metadata_plugin(context, callback):
    """Metadata plugin that always supplies one fixed key/value pair.

    Args:
      context: Ignored.
      callback: Invoked with the metadata, `cygrpc.StatusCode.ok` and empty
        details.
    """
    callback(((
        _CALL_CREDENTIALS_METADATA_KEY,
        _CALL_CREDENTIALS_METADATA_VALUE,
    ),), cygrpc.StatusCode.ok, b'')


class TypeSmokeTest(unittest.TestCase):
    """Creates and destroys cygrpc objects without performing any RPC."""

    def testCompletionQueueUpDown(self):
        completion_queue = cygrpc.CompletionQueue()
        del completion_queue

    def testServerUpDown(self):
        server = cygrpc.Server(set([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ]))
        del server

    def testChannelUpDown(self):
        channel = cygrpc.Channel(b'[::]:0', None, None)
        channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

    def test_metadata_plugin_call_credentials_up_down(self):
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
                                             b'test plugin name!')

    def testServerStartNoExplicitShutdown(self):
        # The server is started and then dropped without calling shutdown().
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        completion_queue = cygrpc.CompletionQueue()
        server.register_completion_queue(completion_queue)
        port = server.add_http2_port(b'[::]:0')
        self.assertIsInstance(port, int)
        server.start()
        del server

    def testServerStartShutdown(self):
        # Shutdown completion is delivered on the registered queue with the
        # tag that was passed to shutdown().
        completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        server.add_http2_port(b'[::]:0')
        server.register_completion_queue(completion_queue)
        server.start()
        shutdown_tag = object()
        server.shutdown(completion_queue, shutdown_tag)
        event = completion_queue.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         event.completion_type)
        self.assertIs(shutdown_tag, event.tag)
        del server
        del completion_queue


class ServerClientMixin(object):
    """Shared server/client fixture and RPC tests.

    Concrete subclasses also derive from `unittest.TestCase` and call
    `setUpMixin` / `tearDownMixin` from their `setUp` / `tearDown`.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an ephemeral port and open a channel to it.

        Args:
          server_credentials: When truthy, the port is added with these
            credentials; otherwise an insecure port is added.
          client_credentials: When truthy, the channel is created with these
            credentials and `ssl_target_name_override` set to `host_override`;
            otherwise an insecure channel is created.
          host_override: When truthy, calls pass no explicit host and the
            server is expected to observe `host_override`; otherwise calls
            pass the arbitrary host b'hostess' and expect it back.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(),
                client_channel_arguments, client_credentials)
        else:
            self.client_channel = cygrpc.Channel(
                'localhost:{}'.format(self.port).encode(), set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and release the fixture objects."""
        self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!')
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_queue_operations(self, operations, call, queue, deadline,
                                  description):
        """Perform the operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description`
        in the message. Performs the operations asynchronously, returning a
        future.

        Args:
          operations: The batch of cygrpc operations to start.
          call: The call on which the batch is started.
          queue: The completion queue polled for the batch's completion.
          deadline: Deadline passed to `queue.poll`.
          description: Label included in the error message on failure.

        Returns:
          A `test_utilities.SimpleFuture` whose result is the completion
          event for the batch.
        """

        def performer():
            tag = object()
            try:
                # NOTE(review): this helper is also used for server-side
                # calls (see test_6522) yet always uses start_client_batch;
                # confirm that is intended.
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Exceptions have no `.message` attribute on Python 3, so
                # format the exception itself (str(error)) to avoid masking
                # the real failure with an AttributeError.
                raise Exception("Error in '{}': {}".format(description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        """Run one unary round trip and verify every batch operation.

        The client sends initial metadata, a request and a half-close, and
        receives initial metadata, a response and status in one batch. The
        server sends initial metadata, receives the request, sends the
        response, receives the close and sends status in one batch.
        """
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_call = self.client_channel.integrated_call(
            0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
            None, [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ])
        # The client batch completes only after the server responds, so wait
        # for it on a separate future while the server side is driven here.
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                        DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.send_close_from_client,
                cygrpc.OperationType.receive_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.receive_status_on_client
            ]), found_client_op_types)

        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set([
                cygrpc.OperationType.send_initial_metadata,
                cygrpc.OperationType.receive_message,
                cygrpc.OperationType.send_message,
                cygrpc.OperationType.receive_close_on_server,
                cygrpc.OperationType.send_status_from_server
            ]), found_server_op_types)

        del client_call
        del server_call

    def test_6522(self):
        """Drive a segregated call through prologue, messages and epilogue.

        Ten rounds of empty messages are exchanged in both directions.
        Presumably a regression test for gRPC issue 6522 (TODO confirm).
        """
        DEADLINE = time.time() + 5
        METHOD = b'twinkies'

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        self.server.request_call(self.server_completion_queue,
                                 self.server_completion_queue,
                                 server_request_tag)
        client_call = self.client_channel.segregated_call(
            0, METHOD, self.host_argument, DEADLINE, None, None, ([(
                [
                    cygrpc.SendInitialMetadataOperation(empty_metadata,
                                                        _EMPTY_FLAGS),
                    cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                ],
                object(),
            ), (
                [
                    cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                ],
                object(),
            )]))

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(operations, server_call,
                                                  self.server_completion_queue,
                                                  DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event)
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
        ], "Client epilogue")
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(lambda: {
            client_call.next_event(),
            client_call.next_event(),
        })

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_events_future.result()  # force completion
        server_event_future.result()


class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with no credentials on either side."""

    def setUp(self):
        self.setUpMixin(None, None, None)

    def tearDown(self):
        self.tearDownMixin()


class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests over SSL using the test resources."""

    def setUp(self):
        server_credentials = cygrpc.server_credentials_ssl(
            None, [
                cygrpc.SslPemKeyCertPair(resources.private_key(),
                                         resources.certificate_chain())
            ], False)
        client_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        # The test certificate is presumably issued for _SSL_HOST_OVERRIDE,
        # hence the target-name override (TODO confirm against resources).
        self.setUpMixin(server_credentials, client_credentials,
                        _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()


if __name__ == '__main__':
    unittest.main(verbosity=2)