• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import platform
16import threading
17import time
18import unittest
19
20from grpc._cython import cygrpc
21
22from tests.unit import resources
23from tests.unit import test_common
24from tests.unit._cython import test_utilities
25
# SSL target-name override handed to the secure client channel (see
# SecureServerSecureClient); also the host the server is expected to observe.
_SSL_HOST_OVERRIDE = b"foo.test.google.fr"
# The single metadata entry emitted by ``_metadata_plugin``.
_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE = (
    "call-creds-key",
    "call-creds-value",
)
# Flags argument passed to every cygrpc operation constructor: no flags set.
_EMPTY_FLAGS = 0
30
31
def _metadata_plugin(context, callback):
    """Metadata plugin that always reports one fixed key/value pair.

    ``context`` is ignored. ``callback`` is invoked once with a one-entry
    metadata tuple, ``cygrpc.StatusCode.ok`` and empty (bytes) details.
    """
    entry = (_CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE)
    callback((entry,), cygrpc.StatusCode.ok, b"")
43
44
class TypeSmokeTest(unittest.TestCase):
    """Create (and sometimes start) core cygrpc objects, then drop them."""

    # Server argument shared by every server constructed in this class.
    _REUSEPORT_OFF = (b"grpc.so_reuseport", 0)

    def testCompletionQueueUpDown(self):
        queue = cygrpc.CompletionQueue()
        del queue

    def testServerUpDown(self):
        # Arguments are deliberately given as a set here; other tests use a list.
        server = cygrpc.Server({self._REUSEPORT_OFF}, False)
        del server

    def testChannelUpDown(self):
        channel = cygrpc.Channel(b"[::]:0", None, None)
        channel.close(cygrpc.StatusCode.cancelled, "Test method anyway!")

    def test_metadata_plugin_call_credentials_up_down(self):
        plugin_name = b"test plugin name!"
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin, plugin_name)

    def testServerStartNoExplicitShutdown(self):
        server = cygrpc.Server([self._REUSEPORT_OFF], False)
        queue = cygrpc.CompletionQueue()
        server.register_completion_queue(queue)
        bound_port = server.add_http2_port(b"[::]:0")
        self.assertIsInstance(bound_port, int)
        server.start()
        # Intentionally drop a started server without calling shutdown().
        del server

    def testServerStartShutdown(self):
        queue = cygrpc.CompletionQueue()
        server = cygrpc.Server([self._REUSEPORT_OFF], False)
        server.add_http2_port(b"[::]:0")
        server.register_completion_queue(queue)
        server.start()

        # Shutdown completion is delivered on the registered queue.
        tag = object()
        server.shutdown(queue, tag)
        shutdown_event = queue.poll()
        self.assertEqual(
            cygrpc.CompletionType.operation_complete,
            shutdown_event.completion_type,
        )
        self.assertIs(tag, shutdown_event.tag)

        del server
        del queue
112        del completion_queue
113
114
class ServerClientMixin(object):
    """End-to-end cygrpc tests shared by the insecure and secure test cases.

    Mix into a ``unittest.TestCase`` subclass whose ``setUp`` calls
    ``setUpMixin`` and whose ``tearDown`` calls ``tearDownMixin``. The
    ``assert*`` methods used here come from that ``TestCase``.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an OS-chosen port and open a client channel to it.

        Args:
          server_credentials: Credentials for the server port, or a falsy
            value to listen without credentials.
          client_credentials: Channel credentials for the client, or a falsy
            value to connect without credentials.
          host_override: Passed as the ``ssl_target_name_override`` channel
            argument when ``client_credentials`` is set. When truthy, calls
            use the default host (``None``) and the server is expected to
            see ``host_override``. Otherwise calls use the explicit host
            ``b"hostess"``.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server(
            [
                (
                    b"grpc.so_reuseport",
                    0,
                )
            ],
            False,
        )
        self.server.register_completion_queue(self.server_completion_queue)
        # Port 0 lets the OS choose; the actually bound port is returned.
        if server_credentials:
            self.port = self.server.add_http2_port(
                b"[::]:0", server_credentials
            )
        else:
            self.port = self.server.add_http2_port(b"[::]:0")
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        target = "localhost:{}".format(self.port).encode()
        if client_credentials:
            client_channel_arguments = (
                (
                    cygrpc.ChannelArgKey.ssl_target_name_override,
                    host_override,
                ),
            )
            self.client_channel = cygrpc.Channel(
                target, client_channel_arguments, client_credentials
            )
        else:
            self.client_channel = cygrpc.Channel(target, set(), None)
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b"hostess"
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Close the client channel and release the server and both queues."""
        self.client_channel.close(cygrpc.StatusCode.ok, "test being torn down!")
        del self.client_channel
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _assert_batch_completed(self, event, tag):
        """Assert ``event`` is a successful completion of the batch ``tag``."""
        self.assertEqual(
            cygrpc.CompletionType.operation_complete, event.completion_type
        )
        self.assertTrue(event.success)
        self.assertIs(tag, event.tag)

    def _perform_queue_operations(
        self, operations, call, queue, deadline, description
    ):
        """Perform the operations with the given call, queue and deadline.

        The batch is started with ``call.start_client_batch`` (``test_6522``
        uses this for a server-side call) and its completion is polled from
        ``queue``. Any failure is re-raised as an ``Exception`` whose message
        includes ``description``.

        Returns:
          A ``test_utilities.SimpleFuture`` wrapping the work. The wrapped
          callable returns the completion event.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self._assert_batch_completed(event, tag)
            except Exception as error:
                # Python 3 exceptions have no ``.message`` attribute, so the
                # old ``error.message`` raised AttributeError and hid the
                # real failure. Format the exception itself; the original
                # stays attached as the implicit ``__context__``.
                raise Exception(
                    "Error in '{}': {}".format(description, error)
                )
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        """Round-trip one unary call and check what both sides observe.

        The client sends initial metadata (an ASCII and a ``-bin`` entry),
        one request message and a half-close in a single batch. The server
        reads the message in one batch, then replies with initial metadata,
        a response message, trailing metadata and a status in a second batch.
        """
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = "key"
        CLIENT_METADATA_ASCII_VALUE = "val"
        CLIENT_METADATA_BIN_KEY = "key-bin"
        CLIENT_METADATA_BIN_VALUE = b"\0" * 1000
        SERVER_INITIAL_METADATA_KEY = "init_me_me_me"
        SERVER_INITIAL_METADATA_VALUE = "whodawha?"
        SERVER_TRAILING_METADATA_KEY = "california_is_in_a_drought"
        SERVER_TRAILING_METADATA_VALUE = "zomg it is"
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = "our work is never over"
        REQUEST = b"in death a member of project mayhem has a name"
        RESPONSE = b"his name is robert paulson"
        METHOD = b"twinkies"

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue,
            self.server_completion_queue,
            server_request_tag,
        )

        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        # The whole client side of the unary call goes in one batch.
        client_call = self.client_channel.integrated_call(
            0,
            METHOD,
            self.host_argument,
            DEADLINE,
            client_initial_metadata,
            None,
            [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            client_initial_metadata, _EMPTY_FLAGS
                        ),
                        cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                        cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    client_call_tag,
                ),
            ],
        )
        client_event_future = test_utilities.SimpleFuture(
            self.client_channel.next_call_event
        )

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(
            cygrpc.CompletionType.operation_complete,
            request_event.completion_type,
        )
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(
                client_initial_metadata, request_event.invocation_metadata
            )
        )
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(
            abs(DEADLINE - request_event.call_details.deadline),
            DEADLINE_TOLERANCE,
        )

        # First server batch: read the request message. Each batch gets its
        # own tag so its completion can be matched below.
        server_call = request_event.call
        server_receive_tag = object()
        server_start_batch_result = server_call.start_server_batch(
            [
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ],
            server_receive_tag,
        )
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_message_event = self.server_completion_queue.poll(
            deadline=DEADLINE
        )

        # Second server batch: reply and finish the call.
        server_send_tag = object()
        server_initial_metadata = (
            (
                SERVER_INITIAL_METADATA_KEY,
                SERVER_INITIAL_METADATA_VALUE,
            ),
        )
        server_trailing_metadata = (
            (
                SERVER_TRAILING_METADATA_KEY,
                SERVER_TRAILING_METADATA_VALUE,
            ),
        )
        server_start_batch_result = server_call.start_server_batch(
            [
                cygrpc.SendInitialMetadataOperation(
                    server_initial_metadata, _EMPTY_FLAGS
                ),
                cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
                cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    server_trailing_metadata,
                    SERVER_STATUS_CODE,
                    SERVER_STATUS_DETAILS,
                    _EMPTY_FLAGS,
                ),
            ],
            server_send_tag,
        )
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        # Client batch: six operations, each type exactly once.
        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if (
                client_result.type()
                == cygrpc.OperationType.receive_initial_metadata
            ):
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata(),
                    )
                )
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif (
                client_result.type()
                == cygrpc.OperationType.receive_status_on_client
            ):
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata(),
                    )
                )
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual(
            set(
                [
                    cygrpc.OperationType.send_initial_metadata,
                    cygrpc.OperationType.send_message,
                    cygrpc.OperationType.send_close_from_client,
                    cygrpc.OperationType.receive_initial_metadata,
                    cygrpc.OperationType.receive_message,
                    cygrpc.OperationType.receive_status_on_client,
                ]
            ),
            found_client_op_types,
        )

        # First server batch: a single receive-message op carrying REQUEST.
        self._assert_batch_completed(server_message_event, server_receive_tag)
        self.assertEqual(1, len(server_message_event.batch_operations))
        (server_receive_result,) = server_message_event.batch_operations
        self.assertEqual(
            cygrpc.OperationType.receive_message, server_receive_result.type()
        )
        self.assertEqual(REQUEST, server_receive_result.message())

        # Second server batch: four distinct ops; the call was not cancelled.
        self._assert_batch_completed(server_event, server_send_tag)
        self.assertEqual(4, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if (
                server_result.type()
                == cygrpc.OperationType.receive_close_on_server
            ):
                self.assertFalse(server_result.cancelled())
        self.assertEqual(
            set(
                [
                    cygrpc.OperationType.send_initial_metadata,
                    cygrpc.OperationType.send_message,
                    cygrpc.OperationType.receive_close_on_server,
                    cygrpc.OperationType.send_status_from_server,
                ]
            ),
            found_server_op_types,
        )

        del client_call
        del server_call

    def test_6522(self):
        """Streaming exchange on a segregated call.

        The exchange runs a prologue, ten message round trips and an
        epilogue. The name presumably refers to a gRPC issue number; that
        cannot be confirmed from this file.
        """
        DEADLINE = time.time() + 5
        METHOD = b"twinkies"

        empty_metadata = ()

        # Prologue
        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue,
            self.server_completion_queue,
            server_request_tag,
        )
        self.assertEqual(cygrpc.CallError.ok, request_call_result)
        client_call = self.client_channel.segregated_call(
            0,
            METHOD,
            self.host_argument,
            DEADLINE,
            None,
            None,
            [
                (
                    [
                        cygrpc.SendInitialMetadataOperation(
                            empty_metadata, _EMPTY_FLAGS
                        ),
                        cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
                (
                    [
                        cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                    ],
                    object(),
                ),
            ],
        )

        client_initial_metadata_event_future = test_utilities.SimpleFuture(
            client_call.next_event
        )

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_queue_operations(
                operations,
                server_call,
                self.server_completion_queue,
                DEADLINE,
                description,
            )

        server_event_future = perform_server_operations(
            [
                cygrpc.SendInitialMetadataOperation(
                    empty_metadata, _EMPTY_FLAGS
                ),
            ],
            "Server prologue",
        )

        client_initial_metadata_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_call.operate(
                [
                    cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS),
                    cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                ],
                "Client message",
            )
            client_message_event_future = test_utilities.SimpleFuture(
                client_call.next_event
            )
            server_event_future = perform_server_operations(
                [
                    cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS),
                    cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                ],
                "Server receive",
            )

            client_message_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_call.operate(
            [
                cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            ],
            "Client epilogue",
        )
        # One for ReceiveStatusOnClient, one for SendCloseFromClient.
        client_events_future = test_utilities.SimpleFuture(
            lambda: {
                client_call.next_event(),
                client_call.next_event(),
            }
        )

        server_event_future = perform_server_operations(
            [
                cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                cygrpc.SendStatusFromServerOperation(
                    empty_metadata, cygrpc.StatusCode.ok, b"", _EMPTY_FLAGS
                ),
            ],
            "Server epilogue",
        )

        client_events_future.result()  # force completion
        server_event_future.result()
542
543
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with no server or channel credentials."""

    def setUp(self):
        self.setUpMixin(
            server_credentials=None,
            client_credentials=None,
            host_override=None,
        )

    def tearDown(self):
        self.tearDownMixin()
550
551
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests over SSL with a target-name override."""

    def setUp(self):
        key_cert_pair = cygrpc.SslPemKeyCertPair(
            resources.private_key(), resources.certificate_chain()
        )
        # First argument None and trailing False: presumably "no client root
        # certificates" and "do not require client auth" - verify in cygrpc.
        server_credentials = cygrpc.server_credentials_ssl(
            None, [key_cert_pair], False
        )
        client_credentials = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None
        )
        self.setUpMixin(
            server_credentials, client_credentials, _SSL_HOST_OVERRIDE
        )

    def tearDown(self):
        self.tearDownMixin()
572
573
if __name__ == "__main__":
    # Run every TestCase in this module, reporting each test by name.
    unittest.main(verbosity=2)
576